From 79c7e86b2b2171bd652a3ed18c0da9ab7d675998 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 7 Oct 2025 17:41:44 -0400 Subject: [PATCH 1/2] refactor observers Signed-off-by: Kyle Sayers --- docs/observers.md | 6 +- .../modifiers/quantization/cache.py | 10 +- .../modifiers/quantization/calibration.py | 75 +--- .../quantization/gptq/gptq_quantize.py | 6 +- src/llmcompressor/observers/base.py | 320 +++-------------- src/llmcompressor/observers/helpers.py | 128 ++++++- src/llmcompressor/observers/min_max.py | 163 ++------- src/llmcompressor/observers/mse.py | 215 ++++------- .../modifiers/calibration/test_cache.py | 2 +- .../modifiers/calibration/test_lifecycle.py | 337 ++++++++++++++++++ .../modifiers/calibration/test_observers.py | 23 +- tests/llmcompressor/observers/test_helpers.py | 129 +++---- tests/llmcompressor/observers/test_min_max.py | 29 +- tests/llmcompressor/observers/test_mse.py | 39 +- 14 files changed, 733 insertions(+), 749 deletions(-) create mode 100644 tests/llmcompressor/modifiers/calibration/test_lifecycle.py diff --git a/docs/observers.md b/docs/observers.md index 342c7dec9..c5dd978a5 100644 --- a/docs/observers.md +++ b/docs/observers.md @@ -65,7 +65,11 @@ from llmcompressor.observers import Observer from compressed_tensors.quantization.quant_args import QuantizationArgs args = QuantizationArgs(num_bits=4, strategy="group", group_size=128) -observer = Observer.load_from_registry("minmax", quantization_args=args) +observer = Observer.load_from_registry( + "minmax", + base_name="weight", + quantization_args=args, +) x = torch.randn(64, 512) scale, zero_point = observer(x) diff --git a/src/llmcompressor/modifiers/quantization/cache.py b/src/llmcompressor/modifiers/quantization/cache.py index b09b41812..53eca8d07 100644 --- a/src/llmcompressor/modifiers/quantization/cache.py +++ b/src/llmcompressor/modifiers/quantization/cache.py @@ -86,13 +86,15 @@ def update( """ if len(self.k_observers) <= layer_idx: - k_observer_name = self.quantization_args.observer k_observer = Observer.load_from_registry( - k_observer_name, quantization_args=self.quantization_args + self.quantization_args.observer, + base_name="k", + args=self.quantization_args, ) - v_observer_name = self.quantization_args.observer v_observer = Observer.load_from_registry( - v_observer_name, quantization_args=self.quantization_args + self.quantization_args.observer, + base_name="v", + args=self.quantization_args, ) # NOTE: User may ignore some layers in configuration, diff --git a/src/llmcompressor/modifiers/quantization/calibration.py b/src/llmcompressor/modifiers/quantization/calibration.py index 96b400d63..5540532c9 100644 --- a/src/llmcompressor/modifiers/quantization/calibration.py +++ b/src/llmcompressor/modifiers/quantization/calibration.py @@ -5,6 +5,7 @@ from compressed_tensors.quantization import ( DynamicType, KVCacheScaleType, + QuantizationArgs, QuantizationScheme, QuantizationStatus, QuantizationStrategy, @@ -19,12 +20,6 @@ from llmcompressor.observers import Observer from llmcompressor.utils.helpers import getattr_chain -DEFAULT_MAXSHRINK = 0.20 -DEFAULT_PATIENCE = 5 -DEFAULT_AVERAGING_CONSTANT = 0.01 -DEFAULT_GRID = 100.0 -DEFAULT_NORM = 2.4 - __all__ = [ "initialize_observer", "update_weight_zp_scale", @@ -54,31 +49,19 @@ def initialize_observer( :param base_name: str used to name the observer attribute """ - - arg_name = "weights" if base_name == "weight" else f"{base_name}_activations" - quantization_scheme = getattr(module, "quantization_scheme", None) - if not quantization_scheme: - # no quantization scheme nothing to do - return - - quantization_args = getattr(quantization_scheme, arg_name, None) - # dont need observers for dynamic - if quantization_args is not None and quantization_args.dynamic in ( - False, - DynamicType.LOCAL, - ): - observer_kwargs = quantization_args.observer_kwargs or {} + if base_name == "weight": + arg_name = "weights" + elif base_name == "output": + arg_name = "output_activations" + else: # input, q, k, v + arg_name = "input_activations" + + args: QuantizationArgs = getattr_chain( + module, f"quantization_scheme.{arg_name}", None + ) + if args is not None and args.dynamic is not True: observer = Observer.load_from_registry( - quantization_args.observer, - quantization_args=quantization_args, - averaging_constant=observer_kwargs.get( - "averaging_constant", DEFAULT_AVERAGING_CONSTANT - ), - # used by mse observer only, will be ignored by minmax observer - maxshrink=observer_kwargs.get("maxshrink", DEFAULT_MAXSHRINK), - patience=observer_kwargs.get("patience", DEFAULT_PATIENCE), - grid=observer_kwargs.get("grid", DEFAULT_GRID), - norm=observer_kwargs.get("norm", DEFAULT_NORM), + args.observer, base_name=base_name, args=args, module=module ) module.register_module(f"{base_name}_observer", observer) @@ -100,36 +83,17 @@ def call_observer( base_name is "weight", then the module's weight tensor will be used """ with align_module_device(module): - if base_name == "weight": - value = module.weight - g_idx = getattr(module, "weight_g_idx", None) - elif value is not None: - g_idx = None - else: - raise ValueError( - "Must provide a value to observe if not using weight observer" - ) - - observer = getattr(module, f"{base_name}_observer") + value = module.weight if base_name == "weight" else value + observer: Observer = getattr(module, f"{base_name}_observer") if should_calculate_gparam: - global_scale = observer( - value, - should_calculate_gparam=True, - ) + global_scale = observer.get_global_scale(value) update_offload_parameter(module, f"{base_name}_global_scale", global_scale) - else: - global_scale = getattr(module, f"{base_name}_global_scale", None) if should_calculate_qparams: - updated_scale, updated_zero_point = observer( - value, g_idx=g_idx, global_scale=global_scale - ) - # register or update scale & zero_point parameters (supports block shapes) - scale_name = f"{base_name}_scale" - zp_name = f"{base_name}_zero_point" - update_offload_parameter(module, scale_name, updated_scale) - update_offload_parameter(module, zp_name, updated_zero_point) + scale, zero_point = observer(value) + update_offload_parameter(module, f"{base_name}_scale", scale) + update_offload_parameter(module, f"{base_name}_zero_point", zero_point) def update_weight_global_scale(module: Module): @@ -148,7 +112,6 @@ def update_weight_global_scale(module: Module): should_calculate_gparam=True, should_calculate_qparams=False, ) - module.weight_observer.reset() def update_weight_zp_scale(module: Module): diff --git a/src/llmcompressor/modifiers/quantization/gptq/gptq_quantize.py b/src/llmcompressor/modifiers/quantization/gptq/gptq_quantize.py index 4392ed8cf..28926650f 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/gptq_quantize.py +++ b/src/llmcompressor/modifiers/quantization/gptq/gptq_quantize.py @@ -95,8 +95,10 @@ def quantize_weight( # create observer for calculating quantization parameters observer = Observer.load_from_registry( - quant_args.observer, - quantization_args=quant_args, + "minmax", + base_name="weight", + args=quant_args, + module=module, averaging_constant=1.0, # ignore moving average ) diff --git a/src/llmcompressor/observers/base.py b/src/llmcompressor/observers/base.py index 6ca6e203c..4af5c37e3 100644 --- a/src/llmcompressor/observers/base.py +++ b/src/llmcompressor/observers/base.py @@ -1,17 +1,16 @@ -from math import ceil -from typing import Any, Iterable, Optional, Tuple, Union +from abc import abstractmethod +from typing import Optional, Tuple +from weakref import ref import torch from compressed_tensors import InternalModule from compressed_tensors.quantization.quant_args import ( - FP8_E4M3_DATA, QuantizationArgs, - QuantizationStrategy, ) -from compressed_tensors.quantization.utils import is_fp4 +from compressed_tensors.quantization.utils import calculate_qparams, generate_gparam from compressed_tensors.registry.registry import RegistryMixin -from loguru import logger -from torch import FloatTensor, IntTensor, Tensor + +from llmcompressor.observers.helpers import flatten_for_calibration __all__ = ["Observer"] @@ -25,287 +24,70 @@ class Observer(InternalModule, RegistryMixin): def __init__( self, - quantization_args: QuantizationArgs, + base_name: str, + args: QuantizationArgs, + module: Optional[torch.nn.Module] = None, + **observer_kwargs, ): - self.quantization_args: QuantizationArgs = quantization_args super().__init__() - self._scale = None - self._zero_point = None - self._num_observed_tokens = None - - @torch.no_grad() - def forward( - self, - observed: Tensor, - g_idx: Optional[Tensor] = None, - global_scale: Optional[Tensor] = None, - should_calculate_gparam: bool = False, - ) -> Tuple[FloatTensor, IntTensor]: - """ - maps directly to get_qparams - :param observed: optional observed tensor from which to calculate - quantization parameters - :param g_idx: optional mapping from column index to group index - :param global_scale: optional scale to further scale local quantization scales - :return: tuple of scale and zero point based on last observed value - """ - self.record_observed_tokens(observed) - if should_calculate_gparam: - return self.get_gparam(observed=observed) - return self.get_qparams( - observed=observed, - g_idx=g_idx, - global_scale=global_scale, - ) + self.module = ref(module) if module is not None else None + self.base_name = base_name + self.args = args - def calculate_qparams( - self, - observed: Tensor, - reduce_dims: Optional[Tuple[int]] = None, - tensor_id: Optional[Any] = None, - global_scale: Optional[Tensor] = None, - ) -> Tuple[FloatTensor, IntTensor]: - """ - :param observed: observed tensor to calculate quantization parameters for - :param reduce_dims: optional tuple of dimensions to reduce along, - returned scale and zero point will be shaped (1,) along the - reduced dimensions - :param tensor_id: optional id for tracking separate statistics when different - ranges of observed tensors are passed, useful for sharding tensors by - group_size or block quantization - :param global_scale: optional scale to further scale local quantization scales - :return: tuple of scale and zero point derived from the observed tensor - """ - raise NotImplementedError(f"{self.__class__} must implement calculate_qparams") + # populate observer kwargs + self.args.observer_kwargs = self.args.observer_kwargs or {} + self.args.observer_kwargs.update(observer_kwargs) - def calculate_gparam( - self, - observed: Tensor, - ) -> torch.Tensor: - """ - :param observed: observed tensor to calculate quantization parameters for - :return: global scale derived from the observed tensor - """ - raise NotImplementedError(f"{self.__class__} must implement calculate_gparam") + # used for moving averages and testing + self.min_vals = None + self.max_vals = None - def post_calculate_qparams(self) -> None: - """ - Run any logic specific to its observers after running calculate_qparams + @abstractmethod + def get_min_max(self, observed: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ + Calculates updated scales and zero points from observed value + (weight, activation, or attention state). - def get_gparam(self, observed: Tensor): - """ - Function to derive a global scale parameter - :param observed: observed tensor to calculate global parameters - from - :return: derived global scale + :param observed: value being observed whose shape is + (num_observations, *qparam_shape, group_size) + :return: minimum value and maximum value whose shapes are (*qparam_shape, ) """ - if self.quantization_args.strategy == QuantizationStrategy.TENSOR_GROUP: - return self.calculate_gparam(observed) - raise NotImplementedError( - "global parameter generation is only supported for TENSOR_GROUP" - ) + raise NotImplementedError() - def get_qparams( - self, - observed: Optional[Tensor] = None, - g_idx: Optional[Tensor] = None, - global_scale: Optional[Tensor] = None, - ) -> Tuple[FloatTensor, IntTensor]: + def forward(self, observed: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ - Convenience function to wrap overwritten calculate_qparams - adds support to make observed tensor optional and support for tracking latest - calculated scale and zero point + Calculates updated scales and zero points from observed value + (weight, activation, or attention state). - :param observed: optional observed tensor to calculate quantization parameters - from - :param g_idx: optional mapping from column index to group index - :param global_scale: optional scale to further scale local quantization scales - :return: tuple of scale and zero point based on last observed value + :param observed: value being observed + :return: calibrated scale and zero point """ - if observed is not None: - group_size = self.quantization_args.group_size - - if self.quantization_args.strategy == QuantizationStrategy.TENSOR: - # re-calculate scale and zero point, update the stored value - self._scale, self._zero_point = self.calculate_qparams(observed) - - elif self.quantization_args.strategy in ( - QuantizationStrategy.TENSOR_GROUP, - QuantizationStrategy.GROUP, - ): - rows = observed.shape[0] - columns = observed.shape[1] - num_groups = int(ceil(columns / group_size)) - if num_groups * group_size != columns: - logger.bind(log_once=True).warning( - "Attempting to quantize a module weight whose columns " - f"({columns}) are not divisible by group_size ({group_size}). " - "This scheme is not supported by vLLM, please consider " - "adjusting the group_size for modules with this number of " - "columns", - ) - - self._scale = torch.empty( - (rows, num_groups), dtype=observed.dtype, device=observed.device - ) - if is_fp4(quantization_args=self.quantization_args): - zp_dtype = FP8_E4M3_DATA.dtype - else: - zp_dtype = self.quantization_args.pytorch_dtype() + g_idx = self._get_module_param("g_idx") + global_scale = self._get_module_param("global_scale") - self._zero_point = torch.empty( - (rows, num_groups), dtype=zp_dtype, device=observed.device - ) + observed = flatten_for_calibration(observed, self.base_name, self.args, g_idx) + self.min_vals, self.max_vals = self.get_min_max(observed) - # support column-order (default) quantization as well as other orderings - # such as activation ordering. Below checks if g_idx has initialized - is_column_order = g_idx is None or -1 in g_idx - if is_column_order: - group_sizes = torch.full((num_groups,), group_size, dtype=torch.int) - else: - group_indices, group_sizes = torch.unique(g_idx, return_counts=True) - group_sizes = group_sizes[torch.argsort(group_indices)] - - perm = torch.argsort(g_idx) - observed = observed.index_select(dim=1, index=perm) - - # TODO: experiment with vectorizing for loop for performance - end = 0 - for group_index, group_count in enumerate(group_sizes): - start = end - end = start + group_count - scale, zero_point = self.get_qparams_along_dim( - observed[:, start:end], - 0, - tensor_id=group_index, - global_scale=global_scale, - ) - - self._scale[:, group_index] = scale.squeeze(1) - self._zero_point[:, group_index] = zero_point.squeeze(1) - - elif self.quantization_args.strategy == QuantizationStrategy.CHANNEL: - # assume observed is transposed, because its the output, hence use dim 0 - self._scale, self._zero_point = self.get_qparams_along_dim(observed, 0) - - elif self.quantization_args.strategy == QuantizationStrategy.TOKEN: - # use dim 1, assume the obsersed.shape = [batch, token, hidden] - # should be batch, token - self._scale, self._zero_point = self.get_qparams_along_dim( - observed, - dim={0, 1}, - ) - - elif self.quantization_args.strategy == QuantizationStrategy.BLOCK: - # Block-wise quantization: one scale/zero_point per block of shape - # [block_rows, block_cols] - rows, cols = observed.shape[:2] - bs = self.quantization_args.block_structure - if not ( - isinstance(bs, (list, tuple)) - and len(bs) == 2 - and all(isinstance(x, int) for x in bs) - ): - raise ValueError( - f"Invalid block_structure '{bs}'. " - f"Must be a list of two ints [rows, cols]." - ) - block_rows, block_cols = bs - num_br = int(ceil(rows / block_rows)) - num_bc = int(ceil(cols / block_cols)) - - # allocate per-block scale and zero_point - self._scale = torch.empty( - (num_br, num_bc), dtype=observed.dtype, device=observed.device - ) - - # Use same dtype logic as GROUP strategy for zero_point - if is_fp4(quantization_args=self.quantization_args): - zp_dtype = FP8_E4M3_DATA.dtype - else: - zp_dtype = self.quantization_args.pytorch_dtype() - - self._zero_point = torch.empty( - (num_br, num_bc), dtype=zp_dtype, device=observed.device - ) - - # compute qparams for each block - for i in range(num_br): - r0 = i * block_rows - r1 = min((i + 1) * block_rows, rows) - for j in range(num_bc): - c0 = j * block_cols - c1 = min((j + 1) * block_cols, cols) - # reduce across both dims to get one scale and zp per block - # Use unique tensor_id for each block to maintain separate stats - block_tensor_id = f"block_{i}_{j}" - scale_bp, zp_bp = self.calculate_qparams( - observed[r0:r1, c0:c1], - reduce_dims=(0, 1), - tensor_id=block_tensor_id, - ) - self._scale[i, j] = scale_bp - self._zero_point[i, j] = zp_bp - - return self._scale, self._zero_point - - def get_qparams_along_dim( - self, - observed, - dim: Union[int, Iterable[int]], - tensor_id: Optional[Any] = None, - global_scale: Optional[Tensor] = None, - ): - if isinstance(dim, int): - dim = [dim] - dim = set(dim) - - reduce_dims = tuple(idx for idx in range(observed.ndim) if idx not in dim) - return self.calculate_qparams( - observed, - reduce_dims=reduce_dims, - tensor_id=tensor_id, + return calculate_qparams( + min_vals=self.min_vals, + max_vals=self.max_vals, + quantization_args=self.args, global_scale=global_scale, ) - def record_observed_tokens(self, batch_tensor: Tensor): + def get_global_scale(self, observed: torch.Tensor) -> torch.nn.Parameter: """ - Counts the number of tokens observed during the - forward passes. The count is aggregated in the - _num_observed_tokens attribute of the class. + Calculates updated global scale from observed value - Note: The batch_tensor is expected to have two dimensions - (batch_size * sequence_length, num_features). This is the - general shape expected by the forward pass of the expert - layers in a MOE model. If the input tensor does not have - two dimensions, the _num_observed_tokens attribute will be set - to None. + :param observed: value being observed + :return: calibrated global parameter """ - if not isinstance(batch_tensor, Tensor): - raise ValueError(f"Expected value to be a tensor, got {type(batch_tensor)}") + observed = observed.reshape((1, 1, -1)) # per tensor reshape + min_vals, max_vals = self.get_min_max(observed) + return generate_gparam(min_vals, max_vals) - if batch_tensor.ndim != 2: - logger.debug( - "The input tensor is expected to have two dimensions " - "(batch_size * sequence_length, num_features). " - f"The input tensor has {batch_tensor.ndim} dimensions." - ) - return + def _get_module_param(self, name: str) -> Optional[torch.nn.Parameter]: + if self.module is None: + return None - if self._num_observed_tokens is None: - # initialize the count - self._num_observed_tokens = 0 - - # batch_tensor (batch_size * sequence_length, num_features) - # observed_tokens (batch_size * sequence_length) - observed_tokens, _ = batch_tensor.shape - self._num_observed_tokens += observed_tokens - - def reset(self): - """ - Reset the state of the observer - """ - self._num_observed_tokens = None - self._scale = None - self._zero_point = None + return getattr(self.module(), f"{self.base_name}_{name}", None) diff --git a/src/llmcompressor/observers/helpers.py b/src/llmcompressor/observers/helpers.py index 5cd32ff64..4560da1b8 100644 --- a/src/llmcompressor/observers/helpers.py +++ b/src/llmcompressor/observers/helpers.py @@ -7,25 +7,125 @@ pruning operations. """ -from collections import Counter +from typing import Optional import torch +from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy +from compressed_tensors.quantization.utils import strategy_cdiv -__all__ = ["get_observer_token_count"] +__all__ = ["flatten_for_calibration"] -def get_observer_token_count(module: torch.nn.Module) -> Counter: +def flatten_for_calibration( + value: torch.Tensor, + base_name: str, + args: QuantizationArgs, + g_idx: Optional[torch.Tensor] = None, +) -> torch.Tensor: """ - Parse the module and return the number of tokens observed by - each module's observer. + Reshapes the value according to the quantization strategy for the purposes of + scale/zp calibration. The value after flattening has the following shape: - :param module: module to parse - :return: counter with the number of tokens observed by each observer + `(num_observations, *qparam_shape, group_size)` + + The first dim is the number of observations (usually the batch size times number of + tokens), the middle dims are the dimension of the scales, and the last dim is the + number of elements being quantized per group. + + :param value: value being flattened + :param base_name: weight, input, output, q/k/v. Used to characterize the value as + being a weight, activation, or attention state + :param args: quantization args for determining how the value is flattened + :param g_idx: optional gidx for weight activation ordering + :return: value which has been reshaped for calibration """ - token_counts = Counter() - for name, module in module.named_modules(): - if name.endswith(".input_observer"): - token_counts[name.replace(".input_observer", "")] = ( - module._num_observed_tokens - ) - return token_counts + if base_name == "weight": + return _flatten_weight(value, args, g_idx) + elif base_name in ("input", "output"): + return _flatten_activation(value, args) + elif base_name in ("q", "k", "v"): + return _flatten_attention(value, args) + else: + raise ValueError(f"Unknown quantization base name: {base_name}") + + +def _flatten_weight( + value: torch.Tensor, args: QuantizationArgs, g_idx: Optional[torch.Tensor] = None +): + if args.strategy == QuantizationStrategy.TENSOR: + # (1, 1, num_weight_elems) + return value.reshape((1, 1, -1)) + + if args.strategy == QuantizationStrategy.TOKEN: + raise ValueError("Token quantization cannot be applied to weights") + + if args.strategy == QuantizationStrategy.CHANNEL: + # (1, num_rows, 1, num_cols) + return value.unsqueeze(-2).unsqueeze(0) + + if args.strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP): + if g_idx is not None: + value = value.index_select(dim=1, index=torch.argsort(g_idx)) + + # (1, num_rows, num_groups, group_size) + return value.unflatten(-1, (-1, args.group_size)).unsqueeze(0) + + if args.strategy == QuantizationStrategy.BLOCK: + # (1, num_block_rows, num_block_cols, block_width * block_height) + block_height, block_width = args.block_structure + rows, cols = value.shape + block_rows = strategy_cdiv(rows, block_height, args.strategy, strict=True) + block_cols = strategy_cdiv(cols, block_width, args.strategy, strict=True) + return ( + value.reshape(block_rows, block_height, block_cols, block_width) + .transpose(1, 2) + .flatten(-2, -1) + .unsqueeze(0) + ) + + assert False, f"Unknown strategy {args.strategy}" + + +def _flatten_activation(value: torch.Tensor, args: QuantizationArgs): + if args.strategy == QuantizationStrategy.TENSOR: + # (batch_size * seq_len, 1, hidden_dim) + return value.reshape((-1, 1, value.size(-1))) + + if args.strategy == QuantizationStrategy.TOKEN: + # (batch_size, seq_len, hidden_dim) + # warning: token quantization uses `compute_dynamic_scales_and_zp` + return value + + if args.strategy == QuantizationStrategy.CHANNEL: + raise ValueError("Channel quantization cannot be applied to activations") + + if args.strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP): + # (batch_size * seq_len, num_groups, group_size) + # warning: group activation quantization uses compute_dynamic_scales_and_zp + return value.flatten(0, 1).unflatten(-1, (-1, args.group_size)) + + if args.strategy == QuantizationStrategy.BLOCK: + raise ValueError("Block quantization cannot be applied to activations") + + assert False, f"Unknown strategy {args.strategy}" + + +def _flatten_attention(value: torch.Tensor, args: QuantizationArgs): + if args.strategy == QuantizationStrategy.TENSOR: + # (batch_size, seq_len, num_heads, head_dim) + # (batch_size * seq_len, 1, num_heads * head_dim) + return value.flatten(0, 1).flatten(-2, -1).unsqueeze(-2) + + if args.strategy == QuantizationStrategy.TOKEN: + raise ValueError("Token quantization cannot be applied to attention") + + if args.strategy == QuantizationStrategy.CHANNEL: + raise ValueError("Channel quantization cannot be applied to attention") + + if args.strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP): + raise ValueError("Group quantization cannot be applied to attention") + + if args.strategy == QuantizationStrategy.BLOCK: + raise ValueError("Block quantization cannot be applied to attention") + + assert False, f"Unknown strategy {args.strategy}" diff --git a/src/llmcompressor/observers/min_max.py b/src/llmcompressor/observers/min_max.py index ce5c0e779..5dbe8f31e 100644 --- a/src/llmcompressor/observers/min_max.py +++ b/src/llmcompressor/observers/min_max.py @@ -1,13 +1,11 @@ -from typing import Any, Optional, Tuple +from typing import Optional, Tuple import torch from compressed_tensors.quantization.quant_args import QuantizationArgs -from compressed_tensors.quantization.utils import calculate_qparams, generate_gparam -from compressed_tensors.utils import deprecated from llmcompressor.observers.base import Observer -__all__ = ["MinMaxObserver", "MovingAverageMinMaxObserver"] +__all__ = ["MinMaxObserver"] @Observer.register("minmax") @@ -20,142 +18,39 @@ class MinMaxObserver(Observer): def __init__( self, - quantization_args: QuantizationArgs, - averaging_constant: float = 0.01, - **kwargs, + base_name: str, + args: QuantizationArgs, + module: Optional[torch.nn.Module] = None, + **observer_kwargs, ): - super().__init__(quantization_args=quantization_args) + super().__init__(base_name, args, module, **observer_kwargs) - self.min_val = {} - self.max_val = {} - self.averaging_constant = averaging_constant + observer_kwargs = self.args.observer_kwargs + self.averaging_constant = observer_kwargs.get("averaging_constant", 0.01) - def calculate_updated_min_max( - self, - observed: torch.Tensor, - reduce_dims: Optional[Tuple[int]] = None, - tensor_id: Optional[Any] = None, - ): + def get_min_max(self, observed: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ - Updates the observed min and max using a moving average smoothed by the - averaging_constant. Set the averaging_constant to 1.0 to disable averaging. + Calculates updated scales and zero points from observed value using the absolute + min and max value. If `averaging_constant` is specified, then subsequent calls + will affect a moving average by the specified constant. - :param observed: observed tensor to calculate quantization parameters for - :param reduce_dims: optional tuple of dimensions to reduce along, - returned scale and zero point will be shaped (1,) along the - reduced dimensions - :param tensor_id: Optional id if different ranges of observed tensors are - passed, useful for sharding tensors by group_size - :return: updated min and max values + :param observed: value being observed whose shape is + (num_observations, *qparam_shape, group_size) + :return: minimum value and maximum value whose shapes are (*qparam_shape, ) """ - tensor_id = tensor_id or "default" - - if not reduce_dims: - min_val, max_val = torch.aminmax(observed) - else: - min_val = torch.amin(observed, dim=reduce_dims, keepdims=True) - max_val = torch.amax(observed, dim=reduce_dims, keepdims=True) - - # early stopping, save some computation and memory - if self.averaging_constant == 1.0: - return min_val, max_val - - running_min_val = self.min_val.get(tensor_id, None) - running_max_val = self.max_val.get(tensor_id, None) - - if running_min_val is None or running_max_val is None: - updated_min_val = min_val - updated_max_val = max_val - else: - updated_min_val = running_min_val + self.averaging_constant * ( - min_val - running_min_val - ) - updated_max_val = running_max_val + self.averaging_constant * ( - max_val - running_max_val - ) - - self.min_val[tensor_id] = updated_min_val - self.max_val[tensor_id] = updated_max_val - return updated_min_val, updated_max_val + min_vals = torch.amin(observed, dim=(0, -1)) + max_vals = torch.amax(observed, dim=(0, -1)) - def calculate_gparam(self, observed: torch.Tensor) -> torch.Tensor: - """ - Generate a global scale using the observed min and max. - - :param observed: observed tensor to calculate quantization parameters for - :return: updated global scale derived from the observed tensor - """ - - updated_min_val, updated_max_val = self.calculate_updated_min_max( - observed=observed - ) - return generate_gparam( - updated_min_val=updated_min_val, updated_max_val=updated_max_val - ) - - def calculate_qparams( - self, - observed: torch.Tensor, - reduce_dims: Optional[Tuple[int]] = None, - tensor_id: Optional[Any] = None, - global_scale: Optional[torch.Tensor] = None, - ) -> Tuple[torch.FloatTensor, torch.IntTensor]: - """ - Generate a scale and zero-point using the observed min and max. - - :param observed: observed tensor to calculate quantization parameters for - :param reduce_dims: optional tuple of dimensions to reduce along, - returned scale and zero point will be shaped (1,) along the - reduced dimensions - :param tensor_id: Optional id if different ranges of observed tensors are - passed, useful for sharding tensors by group_size - :param global_scale: optional scale to further scale local quantization scales - :return: tuple of scale and zero point derived from the observed tensor - """ - - updated_min_val, updated_max_val = self.calculate_updated_min_max( - observed=observed, tensor_id=tensor_id, reduce_dims=reduce_dims - ) - return calculate_qparams( - min_vals=updated_min_val, - max_vals=updated_max_val, - quantization_args=self.quantization_args, - global_scale=global_scale, - ) - - def get_qparams_along_dim( - self, - observed: torch.Tensor, - dim: int, - tensor_id: Optional[Any] = None, - global_scale: Optional[torch.Tensor] = None, - ): - """ - Calculate quantization parameters along the specified dimension - """ - reduce_dims = tuple(idx for idx in range(observed.ndim) if idx != dim) - return self.calculate_qparams( - observed, - reduce_dims=reduce_dims, - tensor_id=tensor_id, - global_scale=global_scale, - ) - - def reset(self): - """ - Reset the state of the observer, including min and maximum values - """ - super().reset() - self.min_val = {} - self.max_val = {} + if self.min_vals is not None and self.averaging_constant != 1.0: + # FUTURE: consider scaling by num observations (first dim) + # rather than reducing by first dim + min_vals = self._lerp(self.min_vals, min_vals, self.averaging_constant) + max_vals = self._lerp(self.max_vals, max_vals, self.averaging_constant) + return min_vals, max_vals -class MovingAverageMinMaxObserver(MinMaxObserver): - @deprecated( - message=( - "The class name `MovingAverageMinMaxObserver` has been deprecated, please " - "initialize with `MinMaxObserver` in the future" - ) - ) - def __new__(cls, *args, **kwargs): - return super().__new__(MinMaxObserver, *args, **kwargs) + def _lerp( + self, input: torch.Tensor, end: torch.Tensor, weight: float + ) -> torch.Tensor: + """torch lerp_kernel is not implemeneted for all data types""" + return (input * (1.0 - weight)) + (end * weight) diff --git a/src/llmcompressor/observers/mse.py b/src/llmcompressor/observers/mse.py index 419155f07..c33c08d6d 100644 --- a/src/llmcompressor/observers/mse.py +++ b/src/llmcompressor/observers/mse.py @@ -1,9 +1,12 @@ -from typing import Any, Optional, Tuple +from typing import Optional, Tuple import torch -from compressed_tensors.quantization.quant_args import QuantizationArgs +from compressed_tensors.quantization.quant_args import ( + QuantizationArgs, + QuantizationStrategy, +) from compressed_tensors.quantization.utils import calculate_qparams -from torch import FloatTensor, IntTensor, Tensor +from compressed_tensors.utils import patch_attr from llmcompressor.observers.base import Observer @@ -19,53 +22,58 @@ class MovingAverageMSEObserver(Observer): def __init__( self, - quantization_args: QuantizationArgs, - maxshrink: float = 0.2, - patience: int = 5, - averaging_constant: float = 0.01, - grid: float = 100.0, - norm: float = 2.4, - **kwargs, + base_name: str, + args: QuantizationArgs, + module: Optional[torch.nn.Module] = None, + **observer_kwargs, ): - super().__init__(quantization_args=quantization_args) + super().__init__(base_name, args, module, **observer_kwargs) - self.min_val = {} - self.max_val = {} - self.maxshrink = maxshrink - self.patience = patience - self.averaging_constant = averaging_constant - self.grid = grid - self.norm = norm + observer_kwargs = self.args.observer_kwargs + self.maxshrink = observer_kwargs.get("maxshrink", 0.20) + self.patience = observer_kwargs.get("patience", 5) + self.averaging_constant = observer_kwargs.get("averaging_constant", 0.01) + self.grid = observer_kwargs.get("grid", 100.0) + self.norm = observer_kwargs.get("norm", 2.4) - def calculate_mse_min_max( - self, - observed: Tensor, - reduce_dims: Optional[Tuple[int]] = None, - global_scale: Optional[torch.Tensor] = None, - ): + def get_min_max(self, observed: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ - Computes the mse-clipped min and max values of the observed tensor by - optimizing for quantization error - - :param observed: observed tensor to calculate quantization parameters for - :param reduce_dims: optional tuple of dimensions to reduce along, - returned values will be shaped (1,) along the reduced dimensions - :param global_scale: optional scale to further scale local quantization scales - :return: tuple of min and max values derived from the observed tensor + Calculates updated scales and zero points from observed value. Minimum and + maximum values are chosen by grid searching across min/max values which minimize + quantization reconstruction loss. + + :param observed: value being observed whose shape is + (num_observations, *qparam_shape, group_size) + :return: minimum value and maximum value whose shapes are (*qparam_shape, ) """ - from compressed_tensors.quantization.lifecycle import fake_quantize + min_vals, max_vals = self._mse_min_max(observed) + + if self.min_vals is not None and self.averaging_constant != 1.0: + # FUTURE: consider scaling by num observations (first dim) + # rather than reducing by first dim + min_vals = self._lerp(self.min_vals, min_vals, self.averaging_constant) + max_vals = self._lerp(self.max_vals, max_vals, self.averaging_constant) - if not reduce_dims: - absolute_min_val, absolute_max_val = torch.aminmax(observed) - else: - absolute_min_val = torch.amin(observed, dim=reduce_dims, keepdims=True) - absolute_max_val = torch.amax(observed, dim=reduce_dims, keepdims=True) + return min_vals, max_vals + + def _mse_min_max(self, observed: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Grid search for MSE-optimal min and max values + + :param observed: value being observed whose shape is + (num_observations, *qparam_shape, group_size) + :return: minimum and maximum values which minimize reconstruction error + """ + from compressed_tensors.quantization.lifecycle import fake_quantize + absolute_min_val = torch.amin(observed, dim=(0, -1)) + absolute_max_val = torch.amax(observed, dim=(0, -1)) best = torch.full_like( absolute_min_val, torch.finfo(absolute_min_val.dtype).max ) min_val = torch.ones_like(absolute_min_val) max_val = torch.zeros_like(absolute_max_val) + global_scale = self._get_module_param("global_scale") # Early stopping params no_improve_count = 0 @@ -78,24 +86,25 @@ def calculate_mse_min_max( candidate_scales, candidate_zero_points = calculate_qparams( min_vals=shrinked_min_val, max_vals=shrinked_max_val, - quantization_args=self.quantization_args, - global_scale=global_scale, - ) - q = fake_quantize( - observed, - candidate_scales, - candidate_zero_points, - self.quantization_args, + quantization_args=self.args, global_scale=global_scale, ) + # Note that observed.shape = (num_observations, *qparams_shape, group_size). + # For the purposes of fake quantization, this is equivalent to token quant + with patch_attr(self.args, "strategy", QuantizationStrategy.TOKEN): + q = fake_quantize( + observed, + candidate_scales.unsqueeze(-1), + candidate_zero_points.unsqueeze(-1), + self.args, + global_scale=global_scale, + ) + q -= observed q.abs_() q.pow_(self.norm) - if not reduce_dims: - err = torch.sum(q) - else: - err = torch.sum(q, reduce_dims, keepdims=True) + err = torch.sum(q, dim=(0, -1)) tmp = err < best if torch.any(tmp): @@ -110,104 +119,8 @@ def calculate_mse_min_max( return min_val, max_val - def calculate_updated_min_max( - self, - observed: Tensor, - reduce_dims: Optional[Tuple[int]] = None, - tensor_id: Optional[Any] = None, - global_scale: Optional[torch.Tensor] = None, - ) -> Tuple[FloatTensor, IntTensor]: - """ - Updates the mse-clipped min and max values of the observed tensor using - a moving average smoothed by the averaging_constant - - :param observed: observed tensor to calculate quantization parameters for - :param reduce_dims: optional tuple of dimensions to reduce along, - returned scale and zero point will be shaped (1,) along the - reduced dimensions - :param tensor_id: Optional id if different ranges of observed tensors are - passed, useful for sharding tensors by group_size - :param global_scale: optional scale to further scale local quantization scales - :return: updated min and max values derived from the observed value - """ - # TODO: will need to be expanded to support fp4 activations; - # currently not supported - min_val, max_val = self.calculate_mse_min_max( - observed, reduce_dims, global_scale=global_scale - ) - - running_min_val = self.min_val.get(tensor_id, None) - running_max_val = self.max_val.get(tensor_id, None) - - if running_min_val is None or running_max_val is None: - updated_min_val = min_val - updated_max_val = max_val - else: - updated_min_val = running_min_val + self.averaging_constant * ( - min_val - running_min_val - ) - updated_max_val = running_max_val + self.averaging_constant * ( - max_val - running_max_val - ) - - tensor_id = tensor_id or "default" - self.min_val[tensor_id] = updated_min_val - self.max_val[tensor_id] = updated_max_val - return updated_min_val, updated_max_val - - def calculate_qparams( - self, - observed: Tensor, - reduce_dims: Optional[Tuple[int]] = None, - tensor_id: Optional[Any] = None, - global_scale: Optional[torch.Tensor] = None, - ) -> Tuple[FloatTensor, IntTensor]: - """ - Updates the mse-clipped min and max values of the observed tensor using - a moving average smoothed by the averaging_constant - - :param observed: observed tensor to calculate quantization parameters for - :param reduce_dims: optional tuple of dimensions to reduce along, - returned scale and zero point will be shaped (1,) along the - reduced dimensions - :param tensor_id: Optional id if different ranges of observed tensors are - passed, useful for sharding tensors by group_size - :param global_scale: optional scale to further scale local quantization scales - :return: tuple of scale and zero point derived from the observed tensor - """ - updated_min_val, updated_max_val = self.calculate_updated_min_max( - observed=observed, - tensor_id=tensor_id, - reduce_dims=reduce_dims, - global_scale=global_scale, - ) - scale, zero_point = calculate_qparams( - min_vals=updated_min_val, - max_vals=updated_max_val, - quantization_args=self.quantization_args, - global_scale=global_scale, - ) - return scale, zero_point - - def get_qparams_along_dim( - self, - observed, - dim: int, - tensor_id: Optional[Any] = None, - global_scale: Optional[torch.Tensor] = None, - ): - reduce_dims = tuple(idx for idx in range(observed.ndim) if idx != dim) - return self.calculate_qparams( - observed, - reduce_dims=reduce_dims, - tensor_id=tensor_id, - global_scale=global_scale, - ) - - def reset(self): - """ - Reset the state of the observer, including min and maximum values - """ - super().reset() - self.min_val = {} - self.max_val = {} + def _lerp( + self, input: torch.Tensor, end: torch.Tensor, weight: float + ) -> torch.Tensor: + """torch lerp_kernel is not implemeneted for all data types""" + return (input * (1.0 - weight)) + (end * weight) diff --git a/tests/llmcompressor/modifiers/calibration/test_cache.py b/tests/llmcompressor/modifiers/calibration/test_cache.py index 9b03234cf..70f0e6125 100644 --- a/tests/llmcompressor/modifiers/calibration/test_cache.py +++ b/tests/llmcompressor/modifiers/calibration/test_cache.py @@ -29,7 +29,7 @@ def test_is_quantized_cache_singleton(): args = QuantizationArgs() cache = QuantizedKVParameterCache(args) observer = args.observer - observer = Observer.load_from_registry(observer, quantization_args=args) + observer = Observer.load_from_registry(observer, base_name="k", args=args) tensor = torch.tensor([1, 2, 3]) cache.k_scales.append(tensor) diff --git a/tests/llmcompressor/modifiers/calibration/test_lifecycle.py b/tests/llmcompressor/modifiers/calibration/test_lifecycle.py new file mode 100644 index 000000000..dae405463 --- /dev/null +++ b/tests/llmcompressor/modifiers/calibration/test_lifecycle.py @@ -0,0 +1,337 @@ +import pytest +import torch +from compressed_tensors.quantization import ( + QuantizationScheme, + forward_quantize, + initialize_module_for_quantization, + initialize_qparams, +) +from compressed_tensors.quantization.quant_args import QuantizationArgs +from compressed_tensors.quantization.quant_config import QuantizationStatus + +from llmcompressor.modifiers.quantization.calibration import initialize_observer + + +@pytest.mark.parametrize( + "args,exp_min_val,exp_max_val,exp_quant,exp_loss", + [ + ( + QuantizationArgs( + num_bits=4, + type="int", + symmetric=True, + strategy="tensor", # equivalent to token + ), + torch.tensor([0.0]), + torch.tensor([23.0]), + torch.tensor( + [ + [0.0000, 0.0000, 3.0625, 3.0625, 3.0625, 6.1250], + [6.1250, 6.1250, 9.1875, 9.1875, 9.1875, 12.2500], + [12.2500, 12.2500, 15.3125, 15.3125, 15.3125, 18.3750], + [18.3750, 18.3750, 21.5000, 21.5000, 21.5000, 21.5000], + ], + dtype=torch.bfloat16, + ), + 0.85, + ), + # token is not supported + ( + QuantizationArgs( + num_bits=4, + type="int", + symmetric=True, + strategy="channel", + ), + torch.tensor([[0], [6], [12], [18]]), + torch.tensor([[5], [11], [17], [23]]), + torch.tensor( + [ + [0.0000, 1.3359, 2.0000, 2.6719, 4.0000, 4.6875], + [5.8750, 7.3438, 7.3438, 8.8125, 10.2500, 10.2500], + [11.3125, 13.6250, 13.6250, 15.8750, 15.8750, 15.8750], + [18.3750, 18.3750, 21.5000, 21.5000, 21.5000, 21.5000], + ], + dtype=torch.bfloat16, + ), + 0.45, + ), + ( + QuantizationArgs( + num_bits=4, + type="int", + symmetric=True, + strategy="group", + group_size=3, + ), + torch.tensor([[0, 3], [6, 9], [12, 15], [18, 21]]), + torch.tensor([[2, 5], [8, 11], [14, 17], [20, 23]]), + torch.tensor( + [ + [0.0000, 1.0703, 1.8750, 2.6719, 4.0000, 4.6875], + [6.4375, 7.5000, 7.5000, 8.8125, 10.2500, 10.2500], + [11.1875, 13.0625, 13.0625, 15.8750, 15.8750, 15.8750], + [18.7500, 18.7500, 18.7500, 21.5000, 21.5000, 21.5000], + ], + ), + 0.45, + ), + ( + QuantizationArgs( + num_bits=4, + type="float", # tensor group requires FP4 + symmetric=True, + strategy="tensor_group", # requires float4 + group_size=3, + ), + torch.tensor([[0, 3], [6, 9], [12, 15], [18, 21]]), + torch.tensor([[2, 5], [8, 11], [14, 17], [20, 23]]), + torch.tensor( + [ + [0.0000, 1.0234, 2.0469, 3.2812, 3.2812, 4.9375], + [5.4688, 8.1875, 8.1875, 10.6875, 10.6875, 10.6875], + [9.8750, 14.7500, 14.7500, 16.3750, 16.3750, 16.3750], + [19.7500, 19.7500, 19.7500, 23.0000, 23.0000, 23.0000], + ], + ), + 1.1, + ), + ( + QuantizationArgs( + num_bits=4, + type="int", + symmetric=True, + strategy="block", + block_structure=[2, 3], + ), + torch.tensor([[0, 3], [12, 15]]), + torch.tensor([[8, 11], [20, 23]]), + torch.tensor( + [ + [0.0000, 1.0703, 2.1406, 2.9375, 4.4062, 4.4062], + [6.4375, 7.5000, 7.5000, 8.8125, 10.2500, 10.2500], + [10.6875, 13.3750, 13.3750, 15.3125, 15.3125, 18.3750], + [18.7500, 18.7500, 18.7500, 21.5000, 21.5000, 21.5000], + ], + ), + 0.5, + ), + ], +) +def test_static_weight_quantization( + args, exp_min_val, exp_max_val, exp_quant, exp_loss +): + """ + weight = tensor([[ 0, 1, 2, 3, 4, 5], + [ 6, 7, 8, 9, 10, 11], + [12, 13, 14, 15, 16, 17], + [18, 19, 20, 21, 22, 23]]) + """ + # set up weight + input_size, output_size = 6, 4 + linear = torch.nn.Linear(input_size, output_size, bias=False) + linear.weight.data = torch.arange( + input_size * output_size, dtype=torch.bfloat16 + ).reshape(output_size, input_size) + + # initialize quantization parameters + scheme = QuantizationScheme(targets=[], weights=args) + initialize_module_for_quantization(linear, scheme) + assert getattr(linear, "quantization_scheme") is scheme + initialize_observer(linear, "weight") + + # calibrate_global_scale + if hasattr(linear, "weight_global_scale"): + global_scale = linear.weight_observer.get_global_scale(linear.weight) + linear.weight_global_scale.data = global_scale + + # calibrate quantization parameters + scale, zero_point = linear.weight_observer(linear.weight) + linear.weight_scale.data = scale + linear.weight_zero_point.data = zero_point + assert torch.equal(linear.weight_observer.min_vals, exp_min_val) + assert torch.equal(linear.weight_observer.max_vals, exp_max_val) + + # forward pass + input = torch.eye(input_size, dtype=torch.bfloat16) + output = linear(input) + + assert torch.allclose(output.T, exp_quant.to(output.dtype)) + assert torch.nn.functional.mse_loss(output.T, linear.weight) <= exp_loss + + +@pytest.mark.parametrize( + "args,exp_min_val,exp_max_val,exp_quant,exp_loss", + [ + ( + QuantizationArgs( + num_bits=4, + type="int", + symmetric=True, + strategy="tensor", + ), + torch.tensor([0.0]), + torch.tensor([11.0]), + torch.tensor( + [ + [ + [0.0000, 1.4688, 1.4688, 2.9375, 4.4062, 4.4062], + [5.8750, 7.3438, 7.3438, 8.8125, 10.2500, 10.2500], + ] + ] + ), + 0.2, + ), + # static token is not supported + # channel is not supported + # group is not supported + ( + QuantizationArgs( + num_bits=4, + type="float", # must be fp4 + symmetric=True, + strategy="tensor_group", + dynamic="local", + group_size=3, + ), + None, + None, + torch.tensor( + [ + [ + [0.0000, 0.9844, 1.9688, 3.4062, 3.4062, 5.1250], + [5.2500, 7.8750, 7.8750, 7.3438, 11.0000, 11.0000], + ] + ] + ), + 0.5, + ), + # block is not supported + # head is not supported + ], +) +def test_static_activation_quantization( + args, exp_min_val, exp_max_val, exp_quant, exp_loss +): + """ + input = tensor([[ 0, 1, 2, 3, 4, 5] + [ 6, 7, 8, 9, 10, 11]]) + """ + # set up activation (and identity weight) + batch_size, seq_len, input_size = 1, 2, 6 + input = torch.arange( + (batch_size * seq_len * input_size), dtype=torch.bfloat16 + ).reshape((batch_size, seq_len, input_size)) + linear = torch.nn.Linear(input_size, input_size, bias=False) + linear.weight.data = torch.eye(input_size, dtype=torch.bfloat16) + + # initialize quantization parameters + scheme = QuantizationScheme(targets=[], input_activations=args) + initialize_module_for_quantization(linear, scheme) + assert getattr(linear, "quantization_scheme") is scheme + initialize_observer(linear, "input") + + # calibrate quantization parameters + def calibrate_input_hook(_, args): + if hasattr(linear, "input_global_scale"): + global_scale = linear.input_observer.get_global_scale(args[0]) + linear.input_global_scale.data = global_scale + + if linear.quantization_scheme.input_activations.dynamic is False: + scale, zero_point = linear.input_observer(args[0]) + linear.input_scale.data = scale + linear.input_zero_point.data = zero_point + + linear.register_forward_pre_hook(calibrate_input_hook) + + # calibration forward pass + output = linear(input) + + # check calibration + if exp_min_val is not None: + assert torch.equal(linear.input_observer.min_vals, exp_min_val) + if exp_max_val is not None: + assert torch.equal(linear.input_observer.max_vals, exp_max_val) + + # check forward pass + assert torch.allclose(output, exp_quant.to(output.dtype)) + assert torch.nn.functional.mse_loss(output, input) <= exp_loss + + +class MockAttention(torch.nn.Module): + pass + + +@pytest.mark.filterwarnings("ignore::UserWarning") # cpu offloading for MockAttention +@pytest.mark.parametrize( + "args,exp_min_val,exp_max_val,exp_quant,exp_loss", + [ + ( + QuantizationArgs( + num_bits=4, + type="int", + symmetric=True, + strategy="tensor", + ), + torch.tensor([0.0]), + torch.tensor([11.0]), + torch.tensor( + [ + [ + [[0.0000, 1.4688, 1.4688], [2.9375, 4.4062, 4.4062]], + [[5.8750, 7.3438, 7.3438], [8.8125, 10.2500, 10.2500]], + ] + ] + ), + 0.19, + ), + # static token is not supported + # channel is not supported + # group is not supported + # tensor group is not supported + # block is not supported + ], +) +def test_static_attention_quantization( + args, exp_min_val, exp_max_val, exp_quant, exp_loss +): + """ + input = tensor([[[[ 0., 1., 2.], + [ 3., 4., 5.]], + [[ 6., 7., 8.], + [ 9., 10., 11.]]]]) + """ + # set up activation (and identity weight) + batch_size, seq_len, num_heads, head_dim = 1, 2, 2, 3 + input = torch.arange( + (batch_size * seq_len * num_heads * head_dim), dtype=torch.bfloat16 + ).reshape((batch_size, seq_len, num_heads, head_dim)) + attention = MockAttention() + + # initialize quantization parameters + scheme = QuantizationScheme(targets=[], input_activations=args) + initialize_qparams( + attention, "k", args, (num_heads, head_dim), observed_dtype=torch.bfloat16 + ) + attention.quantization_scheme = scheme + attention.quantization_status = QuantizationStatus.INITIALIZED + initialize_observer(attention, "k") + + # calibrate quantization parameters + if scheme.input_activations.dynamic is False: + scale, zero_point = attention.k_observer(input) + attention.k_scale.data = scale + attention.k_zero_point.data = zero_point + + # calibration forward pass + output = forward_quantize(attention, input, "k", scheme.input_activations) + + # check calibration + if exp_min_val is not None: + assert torch.equal(attention.k_observer.min_vals, exp_min_val) + if exp_max_val is not None: + assert torch.equal(attention.k_observer.max_vals, exp_max_val) + + # check forward pass + assert torch.allclose(output, exp_quant.to(output.dtype)) + assert torch.nn.functional.mse_loss(output, input) <= exp_loss diff --git a/tests/llmcompressor/modifiers/calibration/test_observers.py b/tests/llmcompressor/modifiers/calibration/test_observers.py index a742a48b2..57f4de40b 100644 --- a/tests/llmcompressor/modifiers/calibration/test_observers.py +++ b/tests/llmcompressor/modifiers/calibration/test_observers.py @@ -13,17 +13,17 @@ "shape,group_size,actorder", [ ((1, 1), None, False), - ((1, 1), 128, False), - ((1, 1), 128, True), + ((1, 1), 1, False), + ((1, 1), 1, True), ((64, 64), None, False), - ((64, 64), 128, False), - ((64, 64), 128, True), - ((1792, 4096), None, False), - ((1792, 4096), 128, False), - ((1792, 4096), 128, True), - ((3420, 64), None, False), - ((3420, 64), 128, False), - ((3420, 64), 128, True), + ((64, 64), 32, False), + ((64, 64), 32, True), + ((896, 4096), None, False), + ((896, 4096), 7, False), + ((896, 4096), 7, True), + ((512, 64), None, False), + ((512, 64), 128, False), + ((512, 64), 128, True), ], ) def test_observers_update(shape, group_size, actorder): @@ -49,8 +49,7 @@ def test_observers_update(shape, group_size, actorder): ("output", output), ): observer = getattr(module, f"{location}_observer") - g_idx = getattr(module, "g_idx", None) - updated_scale, updated_zero_point = observer(value, g_idx=g_idx) + updated_scale, updated_zero_point = observer(value) assert_alike(updated_scale, getattr(module, f"{location}_scale")) assert_alike(updated_zero_point, getattr(module, f"{location}_zero_point")) diff --git a/tests/llmcompressor/observers/test_helpers.py b/tests/llmcompressor/observers/test_helpers.py index 527176019..5b1909828 100644 --- a/tests/llmcompressor/observers/test_helpers.py +++ b/tests/llmcompressor/observers/test_helpers.py @@ -12,98 +12,61 @@ # See the License for the specific language governing permissions and # limitations under the License. +import pytest import torch from compressed_tensors.quantization import ( - QuantizationConfig, - QuantizationStatus, - apply_quantization_config, + QuantizationArgs, + QuantizationScheme, + initialize_module_for_quantization, ) -from transformers import AutoModelForCausalLM, AutoTokenizer - -from llmcompressor.modifiers.quantization.calibration import ( - calibrate_input_hook, - initialize_observer, -) -from llmcompressor.observers.helpers import get_observer_token_count - - -def _prep_for_input_quant_calibration(module: torch.nn.Module): - quantization_scheme = getattr(module, "quantization_scheme", None) - if not quantization_scheme: - return - - module.register_forward_pre_hook(calibrate_input_hook) - module.quantization_status = QuantizationStatus.CALIBRATION +from llmcompressor.observers.helpers import flatten_for_calibration -def test_get_observer_token_count(): - model = AutoModelForCausalLM.from_pretrained("Isotonic/TinyMixtral-4x248M-MoE") - tokenizer = AutoTokenizer.from_pretrained("Isotonic/TinyMixtral-4x248M-MoE") - model.eval() - config = QuantizationConfig( - format="fakequant", - quantization_status="calibration", - config_groups={ - "group_1": { - "input_activations": { - "num_bits": 8, - "type": "int", - "symmetric": False, - "strategy": "tensor", - }, - "targets": ["Linear"], - }, - }, - ) - apply_quantization_config(model, config) - model.apply(lambda module: initialize_observer(module, base_name="input")) - model.apply(_prep_for_input_quant_calibration) - - # start calibration - calib_list = [ - "I am a string that", - "is used for calibration so", - "that your model is", - "quantized properly.", - ] - total_num_tokens_observed = 0 - for calib_sample in calib_list: - calib_tensor = tokenizer(calib_sample, return_tensors="pt") - _ = model(**calib_tensor) - total_num_tokens_observed += len(calib_tensor.input_ids.flatten()) +def make_dummy_g_idx(columns: int, group_size: int) -> torch.Tensor: + perm = torch.randperm(columns) + return torch.tensor([index // group_size for index in range(columns)])[perm] - counter = get_observer_token_count(model) - # filter out the None values - # (tokens, in the appropriate format, that were not observed by the model) - counter = {k: v for k, v in counter.items() if v is not None} +@pytest.mark.parametrize( + "args", + [ + QuantizationArgs(strategy="tensor"), + QuantizationArgs(strategy="tensor_group", group_size=4), + ], +) +def test_flatten_for_calibration_input(args): + module = torch.nn.Linear(8, 10) + scheme = QuantizationScheme(targets=[], input_activations=args) + initialize_module_for_quantization(module, scheme) - # iterate over all the layers in the model where the token count in the proper - # format is has been observed - for i in range(model.config.num_hidden_layers): - # fetch the tokens observed by the router - tokens_observed_by_router = counter.pop( - f"model.layers.{i}.block_sparse_moe.gate" - ) - assert tokens_observed_by_router == total_num_tokens_observed + input = torch.empty((3, 5, 8)) + input_flattened = flatten_for_calibration(input, "input", scheme.input_activations) + assert input_flattened.shape[1:-1] == module.input_scale.shape + assert input_flattened.shape[1:-1] == module.input_zero_point.shape - # fetch the sum of tokens observed by all the experts - sum_tokens_observed_by_experts = 0 - keys_for_this_layer = [ - k - for k in counter.keys() - if f"model.layers.{i}.block_sparse_moe.experts" in k - ] - for key in keys_for_this_layer: - sum_tokens_observed_by_experts += counter.pop(key) - # each Mixtral expert is comprised of 3 linear layers, - # so we need to multiply by 3 - assert ( - sum_tokens_observed_by_experts - == total_num_tokens_observed * model.config.num_experts_per_tok * 3 - ) +@pytest.mark.parametrize( + "args,g_idx", + [ + (QuantizationArgs(strategy="tensor"), None), + (QuantizationArgs(strategy="channel"), None), + (QuantizationArgs(strategy="group", group_size=4), None), + (QuantizationArgs(strategy="group", group_size=4), make_dummy_g_idx(8, 4)), + (QuantizationArgs(strategy="tensor_group", group_size=4), None), + (QuantizationArgs(strategy="block", block_structure=[5, 4]), None), + ], +) +def test_flatten_for_calibration_weights(args, g_idx): + module = torch.nn.Linear(8, 10) + scheme = QuantizationScheme(targets=[], weights=args) + initialize_module_for_quantization(module, scheme) - # there are no more information in the counter - assert len(counter) == 0 + weight_flattened = flatten_for_calibration( + module.weight, + "weight", + scheme.weights, + g_idx=g_idx, + ) + assert weight_flattened.shape[1:-1] == module.weight_scale.shape + assert weight_flattened.shape[1:-1] == module.weight_zero_point.shape diff --git a/tests/llmcompressor/observers/test_min_max.py b/tests/llmcompressor/observers/test_min_max.py index 229c51ca7..8edc0d8e5 100644 --- a/tests/llmcompressor/observers/test_min_max.py +++ b/tests/llmcompressor/observers/test_min_max.py @@ -41,7 +41,7 @@ def test_min_max_observer(symmetric, expected_scale, expected_zero_point): ) observer = weights.observer - observer = Observer.load_from_registry(observer, quantization_args=weights) + observer = Observer.load_from_registry(observer, base_name="weight", args=weights) scale, zero_point = observer(tensor) assert round(scale.item(), 4) == expected_scale @@ -56,7 +56,7 @@ def test_min_max_observer_symmetric_scale_range(): weights = QuantizationArgs(num_bits=num_bits, symmetric=True, observer="minmax") observer = weights.observer - observer = Observer.load_from_registry(observer, quantization_args=weights) + observer = Observer.load_from_registry(observer, base_name="weight", args=weights) scale, zero_point = observer(tensor) # if symmetric, max symmetric_range = abs(-128) / 255 @@ -82,15 +82,17 @@ def test_min_max_observer_value_update(): tensor = inp num_bits = 8 - weights = QuantizationArgs(num_bits=num_bits, symmetric=True, observer="minmax") + weights = QuantizationArgs( + num_bits=num_bits, strategy="tensor", symmetric=True, observer="minmax" + ) observer = weights.observer - observer = Observer.load_from_registry(observer, quantization_args=weights) + observer = Observer.load_from_registry(observer, base_name="weight", args=weights) curr_max = 1 curr_min = 1 for i, tensor in enumerate(tensors): observer(tensor) - curr_max = max(observer.max_val.get("default"), curr_max) - curr_min = min(observer.min_val.get("default"), curr_max) + curr_max = max(observer.max_vals[0], curr_max) + curr_min = min(observer.min_vals[0], curr_min) if i < 2: assert curr_max == 1 @@ -108,13 +110,20 @@ def test_g_idx(): input_shape = (128, 512) tensor = torch.rand(input_shape) weights = QuantizationArgs(num_bits=8, group_size=group_size, observer="minmax") + + module = torch.nn.Linear(512, 1) g_idx = make_dummy_g_idx(tensor.shape[1], group_size) + module.weight_g_idx = g_idx - observer = weights.observer - observer = Observer.load_from_registry(observer, quantization_args=weights) - scale_g_idx, zero_point_g_idx = observer(tensor, g_idx=g_idx) + observer = Observer.load_from_registry( + weights.observer, base_name="weight", args=weights, module=module + ) + scale_g_idx, zero_point_g_idx = observer(tensor) - observer.reset() + observer = Observer.load_from_registry( + weights.observer, base_name="weight", args=weights, module=module + ) + del module.weight_g_idx scale, zero_point = observer(tensor[:, torch.argsort(g_idx)]) assert scale_g_idx == pytest.approx(scale) diff --git a/tests/llmcompressor/observers/test_mse.py b/tests/llmcompressor/observers/test_mse.py index 1ba79495f..f741d4249 100644 --- a/tests/llmcompressor/observers/test_mse.py +++ b/tests/llmcompressor/observers/test_mse.py @@ -15,30 +15,45 @@ import pytest import torch +from compressed_tensors.quantization import fake_quantize from compressed_tensors.quantization.quant_args import QuantizationArgs from llmcompressor.observers import MovingAverageMSEObserver, Observer @pytest.mark.parametrize( - "symmetric,expected_scale,expected_zero_point", + "strategy,symmetric,exp_loss", [ - (True, 0.0078, 0), - (False, 0.0039, -128), + ("tensor", True, 4.8103e-06), + ("tensor", False, 1.1258e-06), + ("channel", True, 2.5675e-06), + ("channel", False, 2.3696e-07), + ("group", True, 3.1282e-06), + ("group", False, 1.3794e-07), + ("block", True, 2.8968e-06), + ("block", False, 5.6068e-07), ], ) -def test_mse_observer(symmetric, expected_scale, expected_zero_point): - tensor = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0]) +def test_mse_observer(strategy, symmetric, exp_loss): + tensor = torch.arange(24).reshape((6, 4)) / 24 num_bits = 8 - weights = QuantizationArgs(num_bits=num_bits, symmetric=symmetric, observer="mse") + weights = QuantizationArgs( + num_bits=num_bits, + strategy=strategy, + symmetric=symmetric, + group_size=(2 if strategy == "group" else None), + block_structure=([3, 2] if strategy == "block" else None), + observer="mse", + ) observer = weights.observer - observer = Observer.load_from_registry(observer, quantization_args=weights) - scale, zero_point = observer(tensor) - + observer = Observer.load_from_registry(observer, base_name="weight", args=weights) assert isinstance(observer, MovingAverageMSEObserver) - assert round(scale.item(), 4) == expected_scale - assert round(zero_point.item(), 4) == expected_zero_point + + scale, zero_point = observer(tensor) + q_tensor = fake_quantize(tensor, scale, zero_point, weights) + mse_loss = torch.sum((tensor - q_tensor).abs_().pow_(2)) / tensor.numel() + assert mse_loss == pytest.approx(exp_loss, abs=1e-10) def test_mse_observer_symmetric_scale_range(): @@ -49,7 +64,7 @@ def test_mse_observer_symmetric_scale_range(): weights = QuantizationArgs(num_bits=num_bits, symmetric=True, observer="mse") observer = weights.observer - observer = Observer.load_from_registry(observer, quantization_args=weights) + observer = Observer.load_from_registry(observer, base_name="weight", args=weights) scale, zero_point = observer(tensor) # if symmetric, max symmetric_range = abs(-128) / 255 From 1c2d550a40c9e761dc803d3ac339d28c5d323656 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 7 Oct 2025 18:29:37 -0400 Subject: [PATCH 2/2] add torch inductor ignore Signed-off-by: Kyle Sayers --- tests/llmcompressor/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/llmcompressor/conftest.py b/tests/llmcompressor/conftest.py index 04fa58928..c0d976e5b 100644 --- a/tests/llmcompressor/conftest.py +++ b/tests/llmcompressor/conftest.py @@ -48,7 +48,7 @@ def _files_size_mb(path_list: List[str]) -> int: @pytest.fixture(scope="session", autouse=True) def check_for_created_files(): - ignore_dirs = ["__pycache__", "sparse_logs"] + ignore_dirs = ["__pycache__", "sparse_logs", "torchinductor"] start_files_root = _get_files(directory=r".", ignore_dirs=ignore_dirs) start_files_temp = _get_files( directory=tempfile.gettempdir(), ignore_dirs=["pytest-of"]