Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 56 additions & 18 deletions src/llmcompressor/modifiers/quantization/calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from torch.nn import Module

from llmcompressor.observers import Observer
from llmcompressor.observers.base import calibrate_module_from_observer

__all__ = [
"initialize_observer",
Expand All @@ -30,6 +31,7 @@
"calibrate_query_hook",
"calibrate_key_hook",
"calibrate_value_hook",
"flush_activation_qparams",
]


Expand Down Expand Up @@ -156,15 +158,20 @@ def update_weight_zp_scale(module: Module):
call_observer(module=module, base_name="weight")


def calibrate_activations(module: Module, value: torch.Tensor, base_name: str):
def calibrate_activations(
module: Module, value: torch.Tensor, base_name: str, stats_only: bool = False
):
"""
Calibrate input or output activations by calling a module's attached
observer.

:param module: torch.nn.Module
:param base_name: substring used to fetch the observer, scales, and zp
:param value: torch.Tensor to be passed to the observer

:param stats_only: if True, only update running statistics in the observer
(accumulate min/max) without computing or writing scale/zero_point.
Used during deferred qparam calibration — qparams are computed once
at epoch end via flush_activation_qparams instead of per batch.
"""
# If empty tensor, can't update zp/scale
# Case for MoEs
Expand All @@ -184,6 +191,15 @@ def calibrate_activations(module: Module, value: torch.Tensor, base_name: str):
if quantization_args.strategy == QuantizationStrategy.TENSOR_GROUP:
calculate_gparam = True

# In deferred (stats_only) mode: call the observer to accumulate running
# min/max stats but do NOT write scale/zero_point yet.
# Qparams are written once at epoch end via flush_activation_qparams.
if stats_only:
observer = getattr(module, f"{base_name}_observer", None)
if observer is not None:
observer.update_deferred_stats(value)
return

call_observer(
module=module,
base_name=base_name,
Expand All @@ -195,44 +211,40 @@ def calibrate_activations(module: Module, value: torch.Tensor, base_name: str):

def calibrate_input_hook(module: Module, args: Any):
"""
Hook to calibrate input activations.
Will call the observers to update the scales/zp before applying
input QDQ in the module's forward pass.
Hook to accumulate input activation statistics (min/max) in the observer.
Scale and zero_point are not written here; they are computed once per subgraph
at epoch end via flush_activation_qparams.
"""
args = args[0] if isinstance(args, tuple) else args
calibrate_activations(module, value=args, base_name="input")
calibrate_activations(module, value=args, base_name="input", stats_only=True)


def calibrate_output_hook(module: Module, _args: Any, output: torch.Tensor):
"""
Hook to calibrate output activations.
Will call the observers to update the scales/zp before applying
output QDQ.
Hook to accumulate output activation statistics (min/max) in the observer.
Scale and zero_point are not written here; they are computed once per subgraph
at epoch end via flush_activation_qparams.
Note: forward_quantize is intentionally absent — hooks only collect statistics.
"""
calibrate_activations(
module,
value=output,
base_name="output",
)
output = forward_quantize(
module=module,
value=output,
base_name="output",
args=module.quantization_scheme.output_activations,
stats_only=True,
)
return output


def calibrate_query_hook(module: Module, query_states: torch.Tensor):
calibrate_activations(module, query_states, base_name="q")
calibrate_activations(module, query_states, base_name="q", stats_only=True)


def calibrate_key_hook(module: Module, key_states: torch.Tensor):
calibrate_activations(module, key_states, base_name="k")
calibrate_activations(module, key_states, base_name="k", stats_only=True)


def calibrate_value_hook(module: Module, value_states: torch.Tensor):
calibrate_activations(module, value_states, base_name="v")
calibrate_activations(module, value_states, base_name="v", stats_only=True)


def apply_calibration_status(module: Module):
Expand Down Expand Up @@ -273,3 +285,29 @@ def reset_quantization_status(model: Module):
for module in model.modules():
if hasattr(module, "quantization_status"):
delattr(module, "quantization_status")


def flush_activation_qparams(module: Module):
    """
    Finalize activation qparams for ``module``: convert each activation
    observer's accumulated running min/max statistics into scale/zero_point
    and free those statistics afterwards.

    Intended to run once per subgraph at SEQUENTIAL_EPOCH_END, replacing the
    per-batch qparam writes that calibration hooks previously performed.
    Modules without a quantization scheme are skipped, as are activations
    without an attached observer.

    Weight observers are untouched here — weight qparams are always computed
    up-front in ``on_start`` via ``update_weight_zp_scale``.

    Apply to targeted modules with:
        for _, module in match_named_modules(...):
            flush_activation_qparams(module)

    :param module: module to flush activation qparams for
    """
    # no scheme attached -> nothing to quantize; bail out early
    if getattr(module, "quantization_scheme", None) is None:
        return

    for activation_name in ("input", "output", "q", "k", "v"):
        calibrate_module_from_observer(module, activation_name)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This name is misleading — it sounds like it's doing weight quantization.

13 changes: 12 additions & 1 deletion src/llmcompressor/modifiers/quantization/quantization/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from llmcompressor.core import Event, EventType, State
from llmcompressor.modifiers import Modifier
from llmcompressor.modifiers.quantization.calibration import (
flush_activation_qparams,
update_weight_global_scale,
update_weight_zp_scale,
)
Expand Down Expand Up @@ -65,7 +66,9 @@ def on_initialize(self, state: State, **kwargs) -> bool:

def on_start(self, state: State, event: Event, **kwargs):
"""
Begin calibrating activations and weights. Calibrate weights only once on start
Begin calibrating activations and weights. Calibrate weights only once on start.
Activation qparams are computed once per subgraph at SEQUENTIAL_EPOCH_END via
flush_activation_qparams, rather than per batch.
"""
self.started_ = True
QuantizationMixin.start_calibration(self, state.model)
Expand Down Expand Up @@ -95,6 +98,14 @@ def on_event(self, state: State, event: Event, **kwargs):
if not self.started_:
self.on_start(state, None)

if event.type_ == EventType.SEQUENTIAL_EPOCH_END:
# Compute scale/zero_point once from accumulated running statistics,
# then free those stats to reduce memory.
for _, module in match_named_modules(
state.model, self.resolved_targets, self.ignore
):
flush_activation_qparams(module)

if event.type_ == EventType.CALIBRATION_EPOCH_END:
if not self.ended_:
self.on_end(state, None)
Expand Down
101 changes: 98 additions & 3 deletions src/llmcompressor/observers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,10 @@
from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
from compressed_tensors.quantization.utils import calculate_qparams, generate_gparam
from compressed_tensors.registry.registry import RegistryMixin
from compressed_tensors.utils import align_module_device

from compressed_tensors.utils import align_module_device, update_offload_parameter
from llmcompressor.observers.helpers import flatten_for_calibration

__all__ = ["Observer", "MinMaxTuple", "ScaleZpTuple"]
__all__ = ["Observer", "MinMaxTuple", "ScaleZpTuple", "calibrate_module_from_observer"]

MinMaxTuple = Tuple[torch.Tensor, torch.Tensor]
ScaleZpTuple = Tuple[torch.Tensor, torch.Tensor]
Expand Down Expand Up @@ -74,6 +73,57 @@ def get_global_min_max(self, observed: torch.Tensor) -> MinMaxTuple:
"""
raise NotImplementedError()

def update_deferred_stats(self, observed: torch.Tensor):
    """
    Fold one batch's global extrema into this observer's running
    accumulators (``_deferred_min`` / ``_deferred_max``).

    Invoked by ``calibrate_activations`` in ``stats_only`` mode; works for
    every observer type, including ``MemorylessMinMaxObserver`` which keeps
    no ``past_min_vals`` history of its own.

    :param observed: activation tensor for the current batch
    """
    current_min = observed.float().min()
    current_max = observed.float().max()

    # first batch seeds the accumulators; later batches fold into them
    if getattr(self, "_deferred_min", None) is None:
        self._deferred_min = current_min
        self._deferred_max = current_max
    else:
        self._deferred_min = torch.min(self._deferred_min, current_min)
        self._deferred_max = torch.max(self._deferred_max, current_max)

def get_accumulated_min_max(self) -> Optional[MinMaxTuple]:
    """
    Fetch the running (min, max) pair built up by ``update_deferred_stats``.

    Works for all observer types, including ``MemorylessMinMaxObserver``.

    :return: accumulated ``(min_vals, max_vals)`` tensors, or None when no
        batch has been observed yet
    """
    lo = getattr(self, "_deferred_min", None)
    hi = getattr(self, "_deferred_max", None)
    if lo is not None and hi is not None:
        return lo, hi
    return None

def clear_accumulated_stats(self):
    """
    Drop every running-statistic attribute (deferred extrema plus any
    per-observer history buffers) so their memory can be reclaimed once
    qparams have been written to the parent module.
    """
    stat_names = (
        "_deferred_min",
        "_deferred_max",
        "past_min_vals",
        "past_max_vals",
        "past_global_min_vals",
        "past_global_max_vals",
    )
    for name in stat_names:
        # EAFP: attribute may never have been created on this observer
        try:
            delattr(self, name)
        except AttributeError:
            pass

@torch.no_grad
def forward(self, observed: torch.Tensor) -> ScaleZpTuple:
"""
Expand Down Expand Up @@ -142,3 +192,48 @@ def _check_has_global_scale(self, global_scale: Optional[torch.nn.Parameter]):
"Cannot compute scale and zero points "
"without first computing global scale"
)


@torch.no_grad()
def calibrate_module_from_observer(
    module: torch.nn.Module,
    base_name: str,
) -> bool:
    """
    Write scale/zero_point qparams onto ``module`` from the running min/max
    statistics its ``{base_name}_observer`` accumulated across calibration
    batches, then release those statistics.

    Deferred counterpart to ``call_observer``: no fresh activation tensor is
    consumed here — qparams come from the observer's previously accumulated
    extrema covering all calibration batches.

    :param module: module whose ``{base_name}_observer`` attribute holds the
        observer
    :param base_name: one of "input", "output", "q", "k", "v"
    :return: True if qparams were updated, False if no observer was attached
        or it had no accumulated stats
    """
    observer = getattr(module, f"{base_name}_observer", None)
    if observer is None:
        return False

    stats = observer.get_accumulated_min_max()
    if stats is None:
        return False

    lo, hi = stats
    gscale = getattr(module, f"{base_name}_global_scale", None)

    with align_module_device(module):
        scale, zero_point = calculate_qparams(
            min_vals=lo,
            max_vals=hi,
            quantization_args=observer.args,
            global_scale=gscale,
        )
        update_offload_parameter(module, f"{base_name}_scale", scale)
        # asymmetric schemes carry a zero_point parameter; symmetric ones may not
        if hasattr(module, f"{base_name}_zero_point"):
            update_offload_parameter(module, f"{base_name}_zero_point", zero_point)

    # running stats are no longer needed once qparams are final — free memory
    observer.clear_accumulated_stats()
    return True
27 changes: 13 additions & 14 deletions src/llmcompressor/pipelines/sequential/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import TYPE_CHECKING, Iterator

import torch
from compressed_tensors.quantization import disable_quantization, enable_quantization
from compressed_tensors.utils import disable_offloading
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm
Expand All @@ -19,7 +20,6 @@
)
from llmcompressor.utils.dev import get_main_device
from llmcompressor.utils.helpers import (
DISABLE_QAC_MODIFIERS,
DisableQuantization,
calibration_forward_context,
)
Expand Down Expand Up @@ -111,18 +111,13 @@ def __call__(

LifecycleCallbacks.calibration_epoch_start()

# TODO: remove this to enable quantization aware calibration
# for GPTQ, AWQ and AutoRound.
disable_qac = any(
type(mod).__name__ in DISABLE_QAC_MODIFIERS
for mod in session.lifecycle.recipe.modifiers
)

with contextlib.ExitStack() as stack:
stack.enter_context(calibration_forward_context(model))
# Optionally disable quantization
if not dataset_args.quantization_aware_calibration or disable_qac:
stack.enter_context(DisableQuantization(model))
# Always disable quantization during calibration so that observer hooks
# accumulate statistics from unquantized activations. Quantization is
# re-enabled during the propagation pass so that downstream subgraphs
# receive realistic (quantized) inputs.
stack.enter_context(DisableQuantization(model))

# prepare intermediates cache
activations = IntermediatesCache.from_dataloader(
Expand All @@ -148,7 +143,7 @@ def __call__(
num_batches = len(dataloader)
use_prefetch = getattr(dataset_args, "sequential_prefetch", False)
with disable_offloading():
# do a preliminary pass to trigger modifier hooks
# calibration pass: hooks accumulate activation statistics
for batch_idx, inputs in _get_batches(
activations,
num_batches,
Expand All @@ -159,10 +154,13 @@ def __call__(
session.state.current_batch_idx = batch_idx
subgraph.forward(model, **inputs)

# flush accumulated stats -> write scale/zero_point once per subgraph
LifecycleCallbacks.sequential_epoch_end(subgraph)

# this pass does not trigger modifier hooks
# and is only used for capturing outputs of newly compressed modules
# propagation pass: modifier hooks are disabled but quantization is
# re-enabled so that compressed module outputs are quantized.
# This ensures downstream subgraphs receive realistic inputs.
model.apply(enable_quantization)
with HooksMixin.disable_hooks():
for batch_idx, inputs in _get_batches(
activations,
Expand All @@ -175,6 +173,7 @@ def __call__(
if subgraph_index < num_subgraphs - 1:
activations.update(batch_idx, output)
activations.delete(batch_idx, subgraph.consumed_names)
model.apply(disable_quantization)

# redundant, finish any remaining compression
LifecycleCallbacks.calibration_epoch_end()
Loading