From 2320a6a8b5efc574b2169cf478b7ac7871abdd07 Mon Sep 17 00:00:00 2001 From: David Zheng Date: Mon, 9 Mar 2026 00:28:27 -0700 Subject: [PATCH 1/4] feat: defer activation qparam calculation to sequential epoch end Fixes #2446 Signed-off-by: dqzhengAP --- .../modifiers/quantization/calibration.py | 68 +++++++++++---- .../quantization/quantization/base.py | 21 ++++- src/llmcompressor/observers/base.py | 82 ++++++++++++++++++- 3 files changed, 153 insertions(+), 18 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/calibration.py b/src/llmcompressor/modifiers/quantization/calibration.py index afc42c0ca6..a08277f803 100644 --- a/src/llmcompressor/modifiers/quantization/calibration.py +++ b/src/llmcompressor/modifiers/quantization/calibration.py @@ -17,6 +17,7 @@ from torch.nn import Module from llmcompressor.observers import Observer +from llmcompressor.observers.base import calibrate_module_from_observer __all__ = [ "initialize_observer", @@ -30,6 +31,7 @@ "calibrate_query_hook", "calibrate_key_hook", "calibrate_value_hook", + "flush_activation_qparams", ] @@ -156,7 +158,9 @@ def update_weight_zp_scale(module: Module): call_observer(module=module, base_name="weight") -def calibrate_activations(module: Module, value: torch.Tensor, base_name: str): +def calibrate_activations( + module: Module, value: torch.Tensor, base_name: str, stats_only: bool = False +): """ Calibrate input or output activations by calling the a module's attached observer. @@ -164,7 +168,10 @@ def calibrate_activations(module: Module, value: torch.Tensor, base_name: str): :param module: torch.nn.Module :param base_name: substring used to fetch the observer, scales, and zp :param value: torch.Tensor to be passed to the observer - + :param stats_only: if True, only update running statistics in the observer + (accumulate min/max) without computing or writing scale/zero_point. 
+ Used during deferred qparam calibration — qparams are computed once + at epoch end via flush_activation_qparams instead of per batch. """ # If empty tensor, can't update zp/scale # Case for MoEs @@ -184,6 +191,12 @@ def calibrate_activations(module: Module, value: torch.Tensor, base_name: str): if quantization_args.strategy == QuantizationStrategy.TENSOR_GROUP: calculate_gparam = True + # In deferred (stats_only) mode, only accumulate running min/max in the + # observer — skip writing scale/zero_point until epoch end. + if stats_only: + calculate_qparams = False + calculate_gparam = False + call_observer( module=module, base_name=base_name, @@ -196,43 +209,40 @@ def calibrate_activations(module: Module, value: torch.Tensor, base_name: str): def calibrate_input_hook(module: Module, args: Any): """ Hook to calibrate input activations. - Will call the observers to update the scales/zp before applying - input QDQ in the module's forward pass. + Accumulates running min/max statistics in the observer without computing + scale/zero_point. Qparams are computed once at epoch end via + flush_activation_qparams (deferred mode). """ args = args[0] if isinstance(args, tuple) else args - calibrate_activations(module, value=args, base_name="input") + calibrate_activations(module, value=args, base_name="input", stats_only=True) def calibrate_output_hook(module: Module, _args: Any, output: torch.Tensor): """ Hook to calibrate output activations. - Will call the observers to update the scales/zp before applying - output QDQ. + Accumulates running min/max statistics only (deferred qparam mode). + Qparams are computed at epoch end; forward_quantize is skipped during + calibration batches since quantization is disabled in the sequential pipeline. 
""" calibrate_activations( module, value=output, base_name="output", - ) - output = forward_quantize( - module=module, - value=output, - base_name="output", - args=module.quantization_scheme.output_activations, + stats_only=True, ) return output def calibrate_query_hook(module: Module, query_states: torch.Tensor): - calibrate_activations(module, query_states, base_name="q") + calibrate_activations(module, query_states, base_name="q", stats_only=True) def calibrate_key_hook(module: Module, key_states: torch.Tensor): - calibrate_activations(module, key_states, base_name="k") + calibrate_activations(module, key_states, base_name="k", stats_only=True) def calibrate_value_hook(module: Module, value_states: torch.Tensor): - calibrate_activations(module, value_states, base_name="v") + calibrate_activations(module, value_states, base_name="v", stats_only=True) def apply_calibration_status(module: Module): @@ -273,3 +283,29 @@ def reset_quantization_status(model: Module): for module in model.modules(): if hasattr(module, "quantization_status"): delattr(module, "quantization_status") + + +def flush_activation_qparams(module: Module): + """ + Compute and write final activation qparams from each observer's accumulated + running statistics, then free those statistics to reduce memory. + + This is called once at SEQUENTIAL_EPOCH_END for each subgraph, replacing the + per-batch qparam updates that were previously triggered by calibration hooks. + It is a no-op for modules with no quantization scheme or no activation observers. + + Note: weight observers are not touched here — weight qparams are always computed + up-front in ``on_start`` via ``update_weight_zp_scale``. 
+ + apply to targeted modules with: + for _, module in match_named_modules(...): + flush_activation_qparams(module) + + :param module: module to flush activation qparams for + """ + scheme = getattr(module, "quantization_scheme", None) + if scheme is None: + return + + for base_name in ("input", "output", "q", "k", "v"): + calibrate_module_from_observer(module, base_name) diff --git a/src/llmcompressor/modifiers/quantization/quantization/base.py b/src/llmcompressor/modifiers/quantization/quantization/base.py index ffd64377f7..37a720ee70 100644 --- a/src/llmcompressor/modifiers/quantization/quantization/base.py +++ b/src/llmcompressor/modifiers/quantization/quantization/base.py @@ -4,6 +4,7 @@ from llmcompressor.core import Event, EventType, State from llmcompressor.modifiers import Modifier from llmcompressor.modifiers.quantization.calibration import ( + flush_activation_qparams, update_weight_global_scale, update_weight_zp_scale, ) @@ -65,7 +66,10 @@ def on_initialize(self, state: State, **kwargs) -> bool: def on_start(self, state: State, event: Event, **kwargs): """ - Begin calibrating activations and weights. Calibrate weights only once on start + Begin calibrating activations and weights. Calibrate weights only once on start. + Quantization is kept DISABLED during calibration batches so that forward passes + run in fp32. Activation qparams are computed once per subgraph at + SEQUENTIAL_EPOCH_END via flush_activation_qparams (deferred mode). """ self.started_ = True QuantizationMixin.start_calibration(self, state.model) @@ -90,11 +94,26 @@ def on_start(self, state: State, event: Event, **kwargs): for _, module in tqdm.tqdm(named_modules, desc="Calibrating weights"): update_weight_zp_scale(module) + # Disable quantization during calibration batches so that fp32 activations + # flow through the model unmodified while hooks accumulate running stats. + # Re-enable once after epoch end when qparams have been flushed. 
+ from compressed_tensors.quantization import disable_quantization + + state.model.apply(disable_quantization) + def on_event(self, state: State, event: Event, **kwargs): if event.type_ == EventType.CALIBRATION_EPOCH_START: if not self.started_: self.on_start(state, None) + if event.type_ == EventType.SEQUENTIAL_EPOCH_END: + # Deferred qparam flush: compute scale/zero_point from accumulated + # running statistics, then free those stats to reduce memory. + for _, module in match_named_modules( + state.model, self.resolved_targets, self.ignore + ): + flush_activation_qparams(module) + if event.type_ == EventType.CALIBRATION_EPOCH_END: if not self.ended_: self.on_end(state, None) diff --git a/src/llmcompressor/observers/base.py b/src/llmcompressor/observers/base.py index 384bbf6ead..5fe3bee846 100644 --- a/src/llmcompressor/observers/base.py +++ b/src/llmcompressor/observers/base.py @@ -11,7 +11,7 @@ from llmcompressor.observers.helpers import flatten_for_calibration -__all__ = ["Observer", "MinMaxTuple", "ScaleZpTuple"] +__all__ = ["Observer", "MinMaxTuple", "ScaleZpTuple", "calibrate_module_from_observer"] MinMaxTuple = Tuple[torch.Tensor, torch.Tensor] ScaleZpTuple = Tuple[torch.Tensor, torch.Tensor] @@ -74,6 +74,39 @@ def get_global_min_max(self, observed: torch.Tensor) -> MinMaxTuple: """ raise NotImplementedError() + def get_accumulated_min_max(self) -> Optional[MinMaxTuple]: + """ + Return the accumulated running min/max statistics stored by this observer, + without performing any new observation. Returns None if no statistics have + been accumulated yet (i.e. no batches have been seen). + + Subclasses which accumulate state (StaticMinMax, MovingAverage) naturally + expose this through their ``past_min_vals`` / ``past_max_vals`` attributes. + Memoryless observers have no running state, so this always returns None. 
+ + :return: (min_vals, max_vals) tensors or None + """ + min_vals = getattr(self, "past_min_vals", None) + max_vals = getattr(self, "past_max_vals", None) + if min_vals is None or max_vals is None: + return None + return min_vals, max_vals + + def clear_accumulated_stats(self): + """ + Delete accumulated running statistics to free memory after qparams have been + computed and written to the parent module. Only clears attributes that exist + on the observer (memoryless observers are unaffected). + """ + for attr in ( + "past_min_vals", + "past_max_vals", + "past_global_min_vals", + "past_global_max_vals", + ): + if hasattr(self, attr): + delattr(self, attr) + @torch.no_grad def forward(self, observed: torch.Tensor) -> ScaleZpTuple: """ @@ -142,3 +175,50 @@ def _check_has_global_scale(self, global_scale: Optional[torch.nn.Parameter]): "Cannot compute scale and zero points " "without first computing global scale" ) + + +@torch.no_grad() +def calibrate_module_from_observer( + module: torch.nn.Module, + base_name: str, +) -> bool: + """ + Flush an observer's accumulated running statistics into the parent module's + quantization parameters (scale / zero_point), then free the running stats. + + This is the deferred counterpart to ``call_observer``. Instead of accepting a + fresh activation tensor, it reads the min/max values that the observer has + already accumulated across all calibration batches and computes qparams from + those final statistics. 
+ + :param module: module whose ``{base_name}_observer`` attribute holds the observer + :param base_name: one of "input", "output", "q", "k", "v" + :return: True if qparams were updated, False if observer had no accumulated stats + """ + from compressed_tensors.utils import align_module_device, update_offload_parameter + + observer: Optional[Observer] = getattr(module, f"{base_name}_observer", None) + if observer is None: + return False + + accumulated = observer.get_accumulated_min_max() + if accumulated is None: + return False + + min_vals, max_vals = accumulated + global_scale = getattr(module, f"{base_name}_global_scale", None) + + with align_module_device(module): + scales, zero_points = calculate_qparams( + min_vals=min_vals, + max_vals=max_vals, + quantization_args=observer.args, + global_scale=global_scale, + ) + update_offload_parameter(module, f"{base_name}_scale", scales) + if hasattr(module, f"{base_name}_zero_point"): + update_offload_parameter(module, f"{base_name}_zero_point", zero_points) + + # Free memory — running stats no longer needed + observer.clear_accumulated_stats() + return True From 26c29adc89e9de2de70a06cdb233c86221e5f34a Mon Sep 17 00:00:00 2001 From: David Zheng Date: Mon, 9 Mar 2026 01:25:26 -0700 Subject: [PATCH 2/4] fix: use update_deferred_stats for all observer types MemorylessMinMaxObserver has no past_min_vals, so get_accumulated_min_max() always returned None, causing scale to remain 0. Fix: add update_deferred_stats() to Observer base class which maintains _deferred_min/_deferred_max independently of subclass implementation. calibrate_activations(stats_only=True) now calls this instead of observer(value). 
Local validation on opt-125m (CPU, 32 calibration samples): - 72/72 modules have input_scale - Perplexity: 28.86 (FP32) -> 30.78 (INT8), 6.7% degradation - No observer stats leaked after calibration Signed-off-by: dqzhengAP --- .../modifiers/quantization/calibration.py | 15 ++++++-- src/llmcompressor/observers/base.py | 38 ++++++++++++++----- 2 files changed, 39 insertions(+), 14 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/calibration.py b/src/llmcompressor/modifiers/quantization/calibration.py index a08277f803..a411b1b296 100644 --- a/src/llmcompressor/modifiers/quantization/calibration.py +++ b/src/llmcompressor/modifiers/quantization/calibration.py @@ -191,11 +191,18 @@ def calibrate_activations( if quantization_args.strategy == QuantizationStrategy.TENSOR_GROUP: calculate_gparam = True - # In deferred (stats_only) mode, only accumulate running min/max in the - # observer — skip writing scale/zero_point until epoch end. + # In deferred (stats_only) mode: call the observer to accumulate running + # min/max stats but do NOT write scale/zero_point yet. + # Qparams are written once at epoch end via flush_activation_qparams. if stats_only: - calculate_qparams = False - calculate_gparam = False + # Deferred mode: accumulate global min/max into the observer's + # _deferred_min / _deferred_max. Works for ALL observer types, + # including MemorylessMinMaxObserver which has no past_min_vals. + # Qparams are written once at epoch end via flush_activation_qparams. 
+ observer = getattr(module, f"{base_name}_observer", None) + if observer is not None: + observer.update_deferred_stats(value) + return call_observer( module=module, diff --git a/src/llmcompressor/observers/base.py b/src/llmcompressor/observers/base.py index 5fe3bee846..26daf6016c 100644 --- a/src/llmcompressor/observers/base.py +++ b/src/llmcompressor/observers/base.py @@ -74,20 +74,37 @@ def get_global_min_max(self, observed: torch.Tensor) -> MinMaxTuple: """ raise NotImplementedError() + def update_deferred_stats(self, observed: torch.Tensor): + """ + Accumulate global min/max from an observed tensor into ``_deferred_min`` + and ``_deferred_max`` on this observer. + + Called by ``calibrate_activations`` in ``stats_only`` mode for ALL observer + types including ``MemorylessMinMaxObserver`` which has no ``past_min_vals``. + + :param observed: activation tensor for this batch + """ + batch_min = observed.float().min() + batch_max = observed.float().max() + + if not hasattr(self, "_deferred_min") or self._deferred_min is None: + self._deferred_min = batch_min + self._deferred_max = batch_max + else: + self._deferred_min = torch.min(self._deferred_min, batch_min) + self._deferred_max = torch.max(self._deferred_max, batch_max) + def get_accumulated_min_max(self) -> Optional[MinMaxTuple]: """ - Return the accumulated running min/max statistics stored by this observer, - without performing any new observation. Returns None if no statistics have - been accumulated yet (i.e. no batches have been seen). + Return accumulated min/max populated by ``update_deferred_stats``. + Returns None if no batches have been seen yet. - Subclasses which accumulate state (StaticMinMax, MovingAverage) naturally - expose this through their ``past_min_vals`` / ``past_max_vals`` attributes. - Memoryless observers have no running state, so this always returns None. + Works for all observer types including ``MemorylessMinMaxObserver``. 
:return: (min_vals, max_vals) tensors or None """ - min_vals = getattr(self, "past_min_vals", None) - max_vals = getattr(self, "past_max_vals", None) + min_vals = getattr(self, "_deferred_min", None) + max_vals = getattr(self, "_deferred_max", None) if min_vals is None or max_vals is None: return None return min_vals, max_vals @@ -95,10 +112,11 @@ def get_accumulated_min_max(self) -> Optional[MinMaxTuple]: def clear_accumulated_stats(self): """ Delete accumulated running statistics to free memory after qparams have been - computed and written to the parent module. Only clears attributes that exist - on the observer (memoryless observers are unaffected). + computed and written to the parent module. """ for attr in ( + "_deferred_min", + "_deferred_max", "past_min_vals", "past_max_vals", "past_global_min_vals", From 316114ad54face69f74ddbf1b1eb7318f5f8876b Mon Sep 17 00:00:00 2001 From: dqzhengAP Date: Sat, 14 Mar 2026 21:06:21 -0700 Subject: [PATCH 3/4] refactor: always disable quantization during calibration, re-enable for propagation - pipeline.py: remove disable_qac / DISABLE_QAC_MODIFIERS conditional logic; quantization is now unconditionally disabled during calibration pass and re-enabled during propagation pass so downstream subgraphs receive quantized inputs - quantization/base.py: remove erroneous disable_quantization call from on_start; control now lives entirely in pipeline layer - observers/base.py: move update_offload_parameter to top-level import - calibration.py: fix hook docstrings to accurately describe stats-only behavior Signed-off-by: dqzhengAP --- .../modifiers/quantization/calibration.py | 19 +++++-------- .../quantization/quantization/base.py | 16 +++-------- src/llmcompressor/observers/base.py | 5 +--- .../pipelines/sequential/pipeline.py | 27 +++++++++---------- 4 files changed, 25 insertions(+), 42 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/calibration.py b/src/llmcompressor/modifiers/quantization/calibration.py index 
a411b1b296..a847f8d0d8 100644 --- a/src/llmcompressor/modifiers/quantization/calibration.py +++ b/src/llmcompressor/modifiers/quantization/calibration.py @@ -195,10 +195,6 @@ def calibrate_activations( # min/max stats but do NOT write scale/zero_point yet. # Qparams are written once at epoch end via flush_activation_qparams. if stats_only: - # Deferred mode: accumulate global min/max into the observer's - # _deferred_min / _deferred_max. Works for ALL observer types, - # including MemorylessMinMaxObserver which has no past_min_vals. - # Qparams are written once at epoch end via flush_activation_qparams. observer = getattr(module, f"{base_name}_observer", None) if observer is not None: observer.update_deferred_stats(value) @@ -215,10 +211,9 @@ def calibrate_activations( def calibrate_input_hook(module: Module, args: Any): """ - Hook to calibrate input activations. - Accumulates running min/max statistics in the observer without computing - scale/zero_point. Qparams are computed once at epoch end via - flush_activation_qparams (deferred mode). + Hook to accumulate input activation statistics (min/max) in the observer. + Scale and zero_point are not written here; they are computed once per subgraph + at epoch end via flush_activation_qparams. """ args = args[0] if isinstance(args, tuple) else args calibrate_activations(module, value=args, base_name="input", stats_only=True) @@ -226,10 +221,10 @@ def calibrate_input_hook(module: Module, args: Any): def calibrate_output_hook(module: Module, _args: Any, output: torch.Tensor): """ - Hook to calibrate output activations. - Accumulates running min/max statistics only (deferred qparam mode). - Qparams are computed at epoch end; forward_quantize is skipped during - calibration batches since quantization is disabled in the sequential pipeline. + Hook to accumulate output activation statistics (min/max) in the observer. 
+ Scale and zero_point are not written here; they are computed once per subgraph + at epoch end via flush_activation_qparams. + Note: forward_quantize is intentionally absent — hooks only collect statistics. """ calibrate_activations( module, diff --git a/src/llmcompressor/modifiers/quantization/quantization/base.py b/src/llmcompressor/modifiers/quantization/quantization/base.py index 37a720ee70..fd1583ab04 100644 --- a/src/llmcompressor/modifiers/quantization/quantization/base.py +++ b/src/llmcompressor/modifiers/quantization/quantization/base.py @@ -67,9 +67,8 @@ def on_initialize(self, state: State, **kwargs) -> bool: def on_start(self, state: State, event: Event, **kwargs): """ Begin calibrating activations and weights. Calibrate weights only once on start. - Quantization is kept DISABLED during calibration batches so that forward passes - run in fp32. Activation qparams are computed once per subgraph at - SEQUENTIAL_EPOCH_END via flush_activation_qparams (deferred mode). + Activation qparams are computed once per subgraph at SEQUENTIAL_EPOCH_END via + flush_activation_qparams, rather than per batch. """ self.started_ = True QuantizationMixin.start_calibration(self, state.model) @@ -94,21 +93,14 @@ def on_start(self, state: State, event: Event, **kwargs): for _, module in tqdm.tqdm(named_modules, desc="Calibrating weights"): update_weight_zp_scale(module) - # Disable quantization during calibration batches so that fp32 activations - # flow through the model unmodified while hooks accumulate running stats. - # Re-enable once after epoch end when qparams have been flushed. 
- from compressed_tensors.quantization import disable_quantization - - state.model.apply(disable_quantization) - def on_event(self, state: State, event: Event, **kwargs): if event.type_ == EventType.CALIBRATION_EPOCH_START: if not self.started_: self.on_start(state, None) if event.type_ == EventType.SEQUENTIAL_EPOCH_END: - # Deferred qparam flush: compute scale/zero_point from accumulated - # running statistics, then free those stats to reduce memory. + # Compute scale/zero_point once from accumulated running statistics, + # then free those stats to reduce memory. for _, module in match_named_modules( state.model, self.resolved_targets, self.ignore ): diff --git a/src/llmcompressor/observers/base.py b/src/llmcompressor/observers/base.py index 26daf6016c..3adf6a5c0c 100644 --- a/src/llmcompressor/observers/base.py +++ b/src/llmcompressor/observers/base.py @@ -7,8 +7,7 @@ from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy from compressed_tensors.quantization.utils import calculate_qparams, generate_gparam from compressed_tensors.registry.registry import RegistryMixin -from compressed_tensors.utils import align_module_device - +from compressed_tensors.utils import align_module_device, update_offload_parameter from llmcompressor.observers.helpers import flatten_for_calibration __all__ = ["Observer", "MinMaxTuple", "ScaleZpTuple", "calibrate_module_from_observer"] @@ -213,8 +212,6 @@ def calibrate_module_from_observer( :param base_name: one of "input", "output", "q", "k", "v" :return: True if qparams were updated, False if observer had no accumulated stats """ - from compressed_tensors.utils import align_module_device, update_offload_parameter - observer: Optional[Observer] = getattr(module, f"{base_name}_observer", None) if observer is None: return False diff --git a/src/llmcompressor/pipelines/sequential/pipeline.py b/src/llmcompressor/pipelines/sequential/pipeline.py index a16693b1a0..2b3bb5f778 100644 --- 
a/src/llmcompressor/pipelines/sequential/pipeline.py +++ b/src/llmcompressor/pipelines/sequential/pipeline.py @@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Iterator import torch +from compressed_tensors.quantization import disable_quantization, enable_quantization from compressed_tensors.utils import disable_offloading from torch.utils.data.dataloader import DataLoader from tqdm import tqdm @@ -19,7 +20,6 @@ ) from llmcompressor.utils.dev import get_main_device from llmcompressor.utils.helpers import ( - DISABLE_QAC_MODIFIERS, DisableQuantization, calibration_forward_context, ) @@ -111,18 +111,13 @@ def __call__( LifecycleCallbacks.calibration_epoch_start() - # TODO: remove this to enable quantization aware calibration - # for GPTQ, AWQ and AutoRound. - disable_qac = any( - type(mod).__name__ in DISABLE_QAC_MODIFIERS - for mod in session.lifecycle.recipe.modifiers - ) - with contextlib.ExitStack() as stack: stack.enter_context(calibration_forward_context(model)) - # Optionally disable quantization - if not dataset_args.quantization_aware_calibration or disable_qac: - stack.enter_context(DisableQuantization(model)) + # Always disable quantization during calibration so that observer hooks + # accumulate statistics from unquantized activations. Quantization is + # re-enabled during the propagation pass so that downstream subgraphs + # receive realistic (quantized) inputs. 
+ stack.enter_context(DisableQuantization(model)) # prepare intermediates cache activations = IntermediatesCache.from_dataloader( @@ -148,7 +143,7 @@ def __call__( num_batches = len(dataloader) use_prefetch = getattr(dataset_args, "sequential_prefetch", False) with disable_offloading(): - # do a preliminary pass to trigger modifier hooks + # calibration pass: hooks accumulate activation statistics for batch_idx, inputs in _get_batches( activations, num_batches, @@ -159,10 +154,13 @@ def __call__( session.state.current_batch_idx = batch_idx subgraph.forward(model, **inputs) + # flush accumulated stats -> write scale/zero_point once per subgraph LifecycleCallbacks.sequential_epoch_end(subgraph) - # this pass does not trigger modifier hooks - # and is only used for capturing outputs of newly compressed modules + # propagation pass: modifier hooks are disabled but quantization is + # re-enabled so that compressed module outputs are quantized. + # This ensures downstream subgraphs receive realistic inputs. 
+ model.apply(enable_quantization) with HooksMixin.disable_hooks(): for batch_idx, inputs in _get_batches( activations, @@ -175,6 +173,7 @@ def __call__( if subgraph_index < num_subgraphs - 1: activations.update(batch_idx, output) activations.delete(batch_idx, subgraph.consumed_names) + model.apply(disable_quantization) # redundant, finish any remaining compression LifecycleCallbacks.calibration_epoch_end() From 684d71048f637bb0e84a8a9f798f50c996765f83 Mon Sep 17 00:00:00 2001 From: dqzhengAP Date: Mon, 16 Mar 2026 16:37:46 -0700 Subject: [PATCH 4/4] refactor: address HDCharles review comments - rename flush_activation_qparams -> write_activation_qparams - rename calibrate_module_from_observer -> update_module_qparams_from_observer - extract ACTIVATION_BASE_NAMES constant in calibration.py - move SEQUENTIAL_EPOCH_END docstring note from on_start to on_event - use ExitStack for propagation pass quantization management - update observer.forward() to accumulate stats alongside computing qparams Signed-off-by: dqzhengAP --- .../modifiers/quantization/calibration.py | 23 +++++++++++-------- .../quantization/quantization/base.py | 11 ++++----- src/llmcompressor/observers/base.py | 12 ++++++---- .../pipelines/sequential/pipeline.py | 6 +++-- 4 files changed, 29 insertions(+), 23 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/calibration.py b/src/llmcompressor/modifiers/quantization/calibration.py index a847f8d0d8..70fff72845 100644 --- a/src/llmcompressor/modifiers/quantization/calibration.py +++ b/src/llmcompressor/modifiers/quantization/calibration.py @@ -17,7 +17,7 @@ from torch.nn import Module from llmcompressor.observers import Observer -from llmcompressor.observers.base import calibrate_module_from_observer +from llmcompressor.observers.base import update_module_qparams_from_observer __all__ = [ "initialize_observer", @@ -31,9 +31,12 @@ "calibrate_query_hook", "calibrate_key_hook", "calibrate_value_hook", - "flush_activation_qparams", + 
"write_activation_qparams", ] +# Activation observer base names used across calibration and quantization code +ACTIVATION_BASE_NAMES = ("input", "output", "q", "k", "v") + def initialize_observer( module: Module, @@ -171,7 +174,7 @@ def calibrate_activations( :param stats_only: if True, only update running statistics in the observer (accumulate min/max) without computing or writing scale/zero_point. Used during deferred qparam calibration — qparams are computed once - at epoch end via flush_activation_qparams instead of per batch. + at epoch end via write_activation_qparams instead of per batch. """ # If empty tensor, can't update zp/scale # Case for MoEs @@ -193,7 +196,7 @@ def calibrate_activations( # In deferred (stats_only) mode: call the observer to accumulate running # min/max stats but do NOT write scale/zero_point yet. - # Qparams are written once at epoch end via flush_activation_qparams. + # Qparams are written once at epoch end via write_activation_qparams. if stats_only: observer = getattr(module, f"{base_name}_observer", None) if observer is not None: @@ -213,7 +216,7 @@ def calibrate_input_hook(module: Module, args: Any): """ Hook to accumulate input activation statistics (min/max) in the observer. Scale and zero_point are not written here; they are computed once per subgraph - at epoch end via flush_activation_qparams. + at epoch end via write_activation_qparams. """ args = args[0] if isinstance(args, tuple) else args calibrate_activations(module, value=args, base_name="input", stats_only=True) @@ -223,7 +226,7 @@ def calibrate_output_hook(module: Module, _args: Any, output: torch.Tensor): """ Hook to accumulate output activation statistics (min/max) in the observer. Scale and zero_point are not written here; they are computed once per subgraph - at epoch end via flush_activation_qparams. + at epoch end via write_activation_qparams. Note: forward_quantize is intentionally absent — hooks only collect statistics. 
""" calibrate_activations( @@ -287,7 +290,7 @@ def reset_quantization_status(model: Module): delattr(module, "quantization_status") -def flush_activation_qparams(module: Module): +def write_activation_qparams(module: Module): """ Compute and write final activation qparams from each observer's accumulated running statistics, then free those statistics to reduce memory. @@ -301,7 +304,7 @@ def flush_activation_qparams(module: Module): apply to targeted modules with: for _, module in match_named_modules(...): - flush_activation_qparams(module) + write_activation_qparams(module) :param module: module to flush activation qparams for """ @@ -309,5 +312,5 @@ def flush_activation_qparams(module: Module): if scheme is None: return - for base_name in ("input", "output", "q", "k", "v"): - calibrate_module_from_observer(module, base_name) + for base_name in ACTIVATION_BASE_NAMES: + update_module_qparams_from_observer(module, base_name) diff --git a/src/llmcompressor/modifiers/quantization/quantization/base.py b/src/llmcompressor/modifiers/quantization/quantization/base.py index fd1583ab04..498e961c3e 100644 --- a/src/llmcompressor/modifiers/quantization/quantization/base.py +++ b/src/llmcompressor/modifiers/quantization/quantization/base.py @@ -4,7 +4,7 @@ from llmcompressor.core import Event, EventType, State from llmcompressor.modifiers import Modifier from llmcompressor.modifiers.quantization.calibration import ( - flush_activation_qparams, + write_activation_qparams, update_weight_global_scale, update_weight_zp_scale, ) @@ -67,8 +67,6 @@ def on_initialize(self, state: State, **kwargs) -> bool: def on_start(self, state: State, event: Event, **kwargs): """ Begin calibrating activations and weights. Calibrate weights only once on start. - Activation qparams are computed once per subgraph at SEQUENTIAL_EPOCH_END via - flush_activation_qparams, rather than per batch. 
""" self.started_ = True QuantizationMixin.start_calibration(self, state.model) @@ -99,12 +97,13 @@ def on_event(self, state: State, event: Event, **kwargs): self.on_start(state, None) if event.type_ == EventType.SEQUENTIAL_EPOCH_END: - # Compute scale/zero_point once from accumulated running statistics, - # then free those stats to reduce memory. + # Activation qparams are computed once per subgraph at SEQUENTIAL_EPOCH_END + # from accumulated running statistics, rather than per batch. + # Running statistics are freed after qparams are written to reduce memory. for _, module in match_named_modules( state.model, self.resolved_targets, self.ignore ): - flush_activation_qparams(module) + write_activation_qparams(module) if event.type_ == EventType.CALIBRATION_EPOCH_END: if not self.ended_: diff --git a/src/llmcompressor/observers/base.py b/src/llmcompressor/observers/base.py index 3adf6a5c0c..2c8cab30dd 100644 --- a/src/llmcompressor/observers/base.py +++ b/src/llmcompressor/observers/base.py @@ -10,7 +10,7 @@ from compressed_tensors.utils import align_module_device, update_offload_parameter from llmcompressor.observers.helpers import flatten_for_calibration -__all__ = ["Observer", "MinMaxTuple", "ScaleZpTuple", "calibrate_module_from_observer"] +__all__ = ["Observer", "MinMaxTuple", "ScaleZpTuple", "update_module_qparams_from_observer"] MinMaxTuple = Tuple[torch.Tensor, torch.Tensor] ScaleZpTuple = Tuple[torch.Tensor, torch.Tensor] @@ -127,12 +127,14 @@ def clear_accumulated_stats(self): @torch.no_grad def forward(self, observed: torch.Tensor) -> ScaleZpTuple: """ - Calculate updated scales and zero points from observed value - (weight, activation, or attention state). + Accumulate the observed value into the deferred min/max statistics, then + calculate updated scales and zero points from the observed value + (weight, activation, or attention state).
:param observed: value being observed - :return: calibrated scale and zero point + :return: calibrated scale and zero point (from accumulated stats) """ + self.update_deferred_stats(observed) scales, zero_points, _min, _max = self._forward_with_minmax(observed) return (scales, zero_points) @@ -195,7 +197,7 @@ def _check_has_global_scale(self, global_scale: Optional[torch.nn.Parameter]): @torch.no_grad() -def calibrate_module_from_observer( +def update_module_qparams_from_observer( module: torch.nn.Module, base_name: str, ) -> bool: diff --git a/src/llmcompressor/pipelines/sequential/pipeline.py b/src/llmcompressor/pipelines/sequential/pipeline.py index 2b3bb5f778..1442228f0a 100644 --- a/src/llmcompressor/pipelines/sequential/pipeline.py +++ b/src/llmcompressor/pipelines/sequential/pipeline.py @@ -160,8 +160,9 @@ def __call__( # propagation pass: modifier hooks are disabled but quantization is # re-enabled so that compressed module outputs are quantized. # This ensures downstream subgraphs receive realistic inputs. - model.apply(enable_quantization) - with HooksMixin.disable_hooks(): + with contextlib.ExitStack() as prop_stack: + prop_stack.enter_context(HooksMixin.disable_hooks()) + model.apply(enable_quantization) for batch_idx, inputs in _get_batches( activations, num_batches, @@ -173,6 +174,7 @@ def __call__( if subgraph_index < num_subgraphs - 1: activations.update(batch_idx, output) activations.delete(batch_idx, subgraph.consumed_names) + # restore disabled quantization for next calibration pass model.apply(disable_quantization) # redundant, finish any remaining compression