diff --git a/src/llmcompressor/modifiers/quantization/calibration.py b/src/llmcompressor/modifiers/quantization/calibration.py index afc42c0ca6..70fff72845 100644 --- a/src/llmcompressor/modifiers/quantization/calibration.py +++ b/src/llmcompressor/modifiers/quantization/calibration.py @@ -17,6 +17,7 @@ from torch.nn import Module from llmcompressor.observers import Observer +from llmcompressor.observers.base import update_module_qparams_from_observer __all__ = [ "initialize_observer", @@ -30,8 +31,12 @@ "calibrate_query_hook", "calibrate_key_hook", "calibrate_value_hook", + "write_activation_qparams", ] +# Activation observer base names used across calibration and quantization code +ACTIVATION_BASE_NAMES = ("input", "output", "q", "k", "v") + def initialize_observer( module: Module, @@ -156,7 +161,9 @@ def update_weight_zp_scale(module: Module): call_observer(module=module, base_name="weight") -def calibrate_activations(module: Module, value: torch.Tensor, base_name: str): +def calibrate_activations( + module: Module, value: torch.Tensor, base_name: str, stats_only: bool = False +): """ Calibrate input or output activations by calling the a module's attached observer. @@ -164,7 +171,10 @@ def calibrate_activations(module: Module, value: torch.Tensor, base_name: str): :param module: torch.nn.Module :param base_name: substring used to fetch the observer, scales, and zp :param value: torch.Tensor to be passed to the observer - + :param stats_only: if True, only update running statistics in the observer + (accumulate min/max) without computing or writing scale/zero_point. + Used during deferred qparam calibration — qparams are computed once + at epoch end via write_activation_qparams instead of per batch. 
""" # If empty tensor, can't update zp/scale # Case for MoEs @@ -184,6 +194,15 @@ def calibrate_activations(module: Module, value: torch.Tensor, base_name: str): if quantization_args.strategy == QuantizationStrategy.TENSOR_GROUP: calculate_gparam = True + # In deferred (stats_only) mode: call the observer to accumulate running + # min/max stats but do NOT write scale/zero_point yet. + # Qparams are written once at epoch end via write_activation_qparams. + if stats_only: + observer = getattr(module, f"{base_name}_observer", None) + if observer is not None: + observer.update_deferred_stats(value) + return + call_observer( module=module, base_name=base_name, @@ -195,44 +214,40 @@ def calibrate_activations(module: Module, value: torch.Tensor, base_name: str): def calibrate_input_hook(module: Module, args: Any): """ - Hook to calibrate input activations. - Will call the observers to update the scales/zp before applying - input QDQ in the module's forward pass. + Hook to accumulate input activation statistics (min/max) in the observer. + Scale and zero_point are not written here; they are computed once per subgraph + at epoch end via write_activation_qparams. """ args = args[0] if isinstance(args, tuple) else args - calibrate_activations(module, value=args, base_name="input") + calibrate_activations(module, value=args, base_name="input", stats_only=True) def calibrate_output_hook(module: Module, _args: Any, output: torch.Tensor): """ - Hook to calibrate output activations. - Will call the observers to update the scales/zp before applying - output QDQ. + Hook to accumulate output activation statistics (min/max) in the observer. + Scale and zero_point are not written here; they are computed once per subgraph + at epoch end via write_activation_qparams. + Note: forward_quantize is intentionally absent — hooks only collect statistics. 
""" calibrate_activations( module, value=output, base_name="output", - ) - output = forward_quantize( - module=module, - value=output, - base_name="output", - args=module.quantization_scheme.output_activations, + stats_only=True, ) return output def calibrate_query_hook(module: Module, query_states: torch.Tensor): - calibrate_activations(module, query_states, base_name="q") + calibrate_activations(module, query_states, base_name="q", stats_only=True) def calibrate_key_hook(module: Module, key_states: torch.Tensor): - calibrate_activations(module, key_states, base_name="k") + calibrate_activations(module, key_states, base_name="k", stats_only=True) def calibrate_value_hook(module: Module, value_states: torch.Tensor): - calibrate_activations(module, value_states, base_name="v") + calibrate_activations(module, value_states, base_name="v", stats_only=True) def apply_calibration_status(module: Module): @@ -273,3 +288,29 @@ def reset_quantization_status(model: Module): for module in model.modules(): if hasattr(module, "quantization_status"): delattr(module, "quantization_status") + + +def write_activation_qparams(module: Module): + """ + Compute and write final activation qparams from each observer's accumulated + running statistics, then free those statistics to reduce memory. + + This is called once at SEQUENTIAL_EPOCH_END for each subgraph, replacing the + per-batch qparam updates that were previously triggered by calibration hooks. + It is a no-op for modules with no quantization scheme or no activation observers. + + Note: weight observers are not touched here — weight qparams are always computed + up-front in ``on_start`` via ``update_weight_zp_scale``. 
+ + apply to targeted modules with: + for _, module in match_named_modules(...): + write_activation_qparams(module) + + :param module: module to flush activation qparams for + """ + scheme = getattr(module, "quantization_scheme", None) + if scheme is None: + return + + for base_name in ACTIVATION_BASE_NAMES: + update_module_qparams_from_observer(module, base_name) diff --git a/src/llmcompressor/modifiers/quantization/quantization/base.py b/src/llmcompressor/modifiers/quantization/quantization/base.py index ffd64377f7..498e961c3e 100644 --- a/src/llmcompressor/modifiers/quantization/quantization/base.py +++ b/src/llmcompressor/modifiers/quantization/quantization/base.py @@ -4,6 +4,7 @@ from llmcompressor.core import Event, EventType, State from llmcompressor.modifiers import Modifier from llmcompressor.modifiers.quantization.calibration import ( + write_activation_qparams, update_weight_global_scale, update_weight_zp_scale, ) @@ -65,7 +66,7 @@ def on_initialize(self, state: State, **kwargs) -> bool: def on_start(self, state: State, event: Event, **kwargs): """ - Begin calibrating activations and weights. Calibrate weights only once on start + Begin calibrating activations and weights. Calibrate weights only once on start. """ self.started_ = True QuantizationMixin.start_calibration(self, state.model) @@ -95,6 +96,15 @@ def on_event(self, state: State, event: Event, **kwargs): if not self.started_: self.on_start(state, None) + if event.type_ == EventType.SEQUENTIAL_EPOCH_END: + # Activation qparams are computed once per subgraph at SEQUENTIAL_EPOCH_END + # from accumulated running statistics, rather than per batch. + # Running statistics are freed after qparams are written to reduce memory. 
+ for _, module in match_named_modules( + state.model, self.resolved_targets, self.ignore + ): + write_activation_qparams(module) + if event.type_ == EventType.CALIBRATION_EPOCH_END: if not self.ended_: self.on_end(state, None) diff --git a/src/llmcompressor/observers/base.py b/src/llmcompressor/observers/base.py index 384bbf6ead..2c8cab30dd 100644 --- a/src/llmcompressor/observers/base.py +++ b/src/llmcompressor/observers/base.py @@ -7,11 +7,10 @@ from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy from compressed_tensors.quantization.utils import calculate_qparams, generate_gparam from compressed_tensors.registry.registry import RegistryMixin -from compressed_tensors.utils import align_module_device - +from compressed_tensors.utils import align_module_device, update_offload_parameter from llmcompressor.observers.helpers import flatten_for_calibration -__all__ = ["Observer", "MinMaxTuple", "ScaleZpTuple"] +__all__ = ["Observer", "MinMaxTuple", "ScaleZpTuple", "update_module_qparams_from_observer"] MinMaxTuple = Tuple[torch.Tensor, torch.Tensor] ScaleZpTuple = Tuple[torch.Tensor, torch.Tensor] @@ -74,15 +73,68 @@ def get_global_min_max(self, observed: torch.Tensor) -> MinMaxTuple: """ raise NotImplementedError() + def update_deferred_stats(self, observed: torch.Tensor): + """ + Accumulate global min/max from an observed tensor into ``_deferred_min`` + and ``_deferred_max`` on this observer. + + Called by ``calibrate_activations`` in ``stats_only`` mode for ALL observer + types including ``MemorylessMinMaxObserver`` which has no ``past_min_vals``. 
+ + :param observed: activation tensor for this batch + """ + batch_min = observed.float().min() + batch_max = observed.float().max() + + if not hasattr(self, "_deferred_min") or self._deferred_min is None: + self._deferred_min = batch_min + self._deferred_max = batch_max + else: + self._deferred_min = torch.min(self._deferred_min, batch_min) + self._deferred_max = torch.max(self._deferred_max, batch_max) + + def get_accumulated_min_max(self) -> Optional[MinMaxTuple]: + """ + Return accumulated min/max populated by ``update_deferred_stats``. + Returns None if no batches have been seen yet. + + Works for all observer types including ``MemorylessMinMaxObserver``. + + :return: (min_vals, max_vals) tensors or None + """ + min_vals = getattr(self, "_deferred_min", None) + max_vals = getattr(self, "_deferred_max", None) + if min_vals is None or max_vals is None: + return None + return min_vals, max_vals + + def clear_accumulated_stats(self): + """ + Delete accumulated running statistics to free memory after qparams have been + computed and written to the parent module. + """ + for attr in ( + "_deferred_min", + "_deferred_max", + "past_min_vals", + "past_max_vals", + "past_global_min_vals", + "past_global_max_vals", + ): + if hasattr(self, attr): + delattr(self, attr) + @torch.no_grad def forward(self, observed: torch.Tensor) -> ScaleZpTuple: """ - Calculate updated scales and zero points from observed value - (weight, activation, or attention state). + Accumulate running statistics (deferred min/max) from the observed value. + Scale and zero_point are still computed and returned by this call; the + final activation qparams are written at epoch end via update_module_qparams_from_observer. 
:param observed: value being observed - :return: calibrated scale and zero point + :return: calibrated scale and zero point (from accumulated stats) """ + self.update_deferred_stats(observed) scales, zero_points, _min, _max = self._forward_with_minmax(observed) return (scales, zero_points) @@ -142,3 +194,48 @@ def _check_has_global_scale(self, global_scale: Optional[torch.nn.Parameter]): "Cannot compute scale and zero points " "without first computing global scale" ) + + +@torch.no_grad() +def update_module_qparams_from_observer( + module: torch.nn.Module, + base_name: str, +) -> bool: + """ + Flush an observer's accumulated running statistics into the parent module's + quantization parameters (scale / zero_point), then free the running stats. + + This is the deferred counterpart to ``call_observer``. Instead of accepting a + fresh activation tensor, it reads the min/max values that the observer has + already accumulated across all calibration batches and computes qparams from + those final statistics. 
+ + :param module: module whose ``{base_name}_observer`` attribute holds the observer + :param base_name: one of "input", "output", "q", "k", "v" + :return: True if qparams were updated, False if observer had no accumulated stats + """ + observer: Optional[Observer] = getattr(module, f"{base_name}_observer", None) + if observer is None: + return False + + accumulated = observer.get_accumulated_min_max() + if accumulated is None: + return False + + min_vals, max_vals = accumulated + global_scale = getattr(module, f"{base_name}_global_scale", None) + + with align_module_device(module): + scales, zero_points = calculate_qparams( + min_vals=min_vals, + max_vals=max_vals, + quantization_args=observer.args, + global_scale=global_scale, + ) + update_offload_parameter(module, f"{base_name}_scale", scales) + if hasattr(module, f"{base_name}_zero_point"): + update_offload_parameter(module, f"{base_name}_zero_point", zero_points) + + # Free memory — running stats no longer needed + observer.clear_accumulated_stats() + return True diff --git a/src/llmcompressor/pipelines/sequential/pipeline.py b/src/llmcompressor/pipelines/sequential/pipeline.py index a16693b1a0..1442228f0a 100644 --- a/src/llmcompressor/pipelines/sequential/pipeline.py +++ b/src/llmcompressor/pipelines/sequential/pipeline.py @@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Iterator import torch +from compressed_tensors.quantization import disable_quantization, enable_quantization from compressed_tensors.utils import disable_offloading from torch.utils.data.dataloader import DataLoader from tqdm import tqdm @@ -19,7 +20,6 @@ ) from llmcompressor.utils.dev import get_main_device from llmcompressor.utils.helpers import ( - DISABLE_QAC_MODIFIERS, DisableQuantization, calibration_forward_context, ) @@ -111,18 +111,13 @@ def __call__( LifecycleCallbacks.calibration_epoch_start() - # TODO: remove this to enable quantization aware calibration - # for GPTQ, AWQ and AutoRound. 
- disable_qac = any( - type(mod).__name__ in DISABLE_QAC_MODIFIERS - for mod in session.lifecycle.recipe.modifiers - ) - with contextlib.ExitStack() as stack: stack.enter_context(calibration_forward_context(model)) - # Optionally disable quantization - if not dataset_args.quantization_aware_calibration or disable_qac: - stack.enter_context(DisableQuantization(model)) + # Always disable quantization during calibration so that observer hooks + # accumulate statistics from unquantized activations. Quantization is + # re-enabled during the propagation pass so that downstream subgraphs + # receive realistic (quantized) inputs. + stack.enter_context(DisableQuantization(model)) # prepare intermediates cache activations = IntermediatesCache.from_dataloader( @@ -148,7 +143,7 @@ def __call__( num_batches = len(dataloader) use_prefetch = getattr(dataset_args, "sequential_prefetch", False) with disable_offloading(): - # do a preliminary pass to trigger modifier hooks + # calibration pass: hooks accumulate activation statistics for batch_idx, inputs in _get_batches( activations, num_batches, @@ -159,11 +154,15 @@ def __call__( session.state.current_batch_idx = batch_idx subgraph.forward(model, **inputs) + # flush accumulated stats -> write scale/zero_point once per subgraph LifecycleCallbacks.sequential_epoch_end(subgraph) - # this pass does not trigger modifier hooks - # and is only used for capturing outputs of newly compressed modules - with HooksMixin.disable_hooks(): + # propagation pass: modifier hooks are disabled but quantization is + # re-enabled so that compressed module outputs are quantized. + # This ensures downstream subgraphs receive realistic inputs. 
+ with contextlib.ExitStack() as prop_stack: + prop_stack.enter_context(HooksMixin.disable_hooks()) + model.apply(enable_quantization) for batch_idx, inputs in _get_batches( activations, num_batches, @@ -175,6 +174,8 @@ def __call__( if subgraph_index < num_subgraphs - 1: activations.update(batch_idx, output) activations.delete(batch_idx, subgraph.consumed_names) + # restore disabled quantization for next calibration pass + model.apply(disable_quantization) # redundant, finish any remaining compression LifecycleCallbacks.calibration_epoch_end()