diff --git a/src/llmcompressor/modifiers/quantization/calibration.py b/src/llmcompressor/modifiers/quantization/calibration.py
index 2900f6bd3..dc15381d8 100644
--- a/src/llmcompressor/modifiers/quantization/calibration.py
+++ b/src/llmcompressor/modifiers/quantization/calibration.py
@@ -147,7 +147,6 @@ def update_weight_global_scale(module: Module):
         should_calculate_gparam=True,
         should_calculate_qparams=False,
     )
-    module.weight_observer.reset()
 
 
 def update_weight_zp_scale(module: Module):
@@ -199,6 +198,10 @@ def calibrate_activations(module: Module, value: torch.Tensor, base_name: str):
         if quantization_args.strategy == QuantizationStrategy.TENSOR_GROUP:
             calculate_gparam = True
 
+    # (..., 1, hidden_dim)
+    # the second-to-last dim indicates that activations have one output channel
+    value = value.flatten(0, -2).unsqueeze(-2)
+
     call_observer(
         module=module,
         base_name=base_name,
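For reference, the reshape added to `calibrate_activations` can be sanity-checked in isolation. A minimal sketch (shapes are illustrative, not part of the patch): all leading dims are flattened into a single token dim, then a singleton dim is inserted so the observer's second-to-last-dim reduction sees activations as having one output channel.

import torch

value = torch.randn(2, 5, 8)   # (batch_size, seq_len, hidden_dim)
value = value.flatten(0, -2)   # (batch_size * seq_len, hidden_dim)
value = value.unsqueeze(-2)    # (batch_size * seq_len, 1, hidden_dim)
assert value.shape == (10, 1, 8)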
diff --git a/src/llmcompressor/observers/base.py b/src/llmcompressor/observers/base.py
index ea325c7dc..d010c1689 100644
--- a/src/llmcompressor/observers/base.py
+++ b/src/llmcompressor/observers/base.py
@@ -1,4 +1,3 @@
-from math import ceil
 from typing import Any, Iterable, Optional, Tuple, Union
 
 import torch
@@ -8,7 +7,7 @@
     QuantizationArgs,
     QuantizationStrategy,
 )
-from compressed_tensors.quantization.utils import is_fp4
+from compressed_tensors.quantization.utils import is_fp4, strategy_cdiv
 from compressed_tensors.registry.registry import RegistryMixin
 from loguru import logger
 from torch import FloatTensor, IntTensor, Tensor
@@ -127,60 +126,54 @@ def get_qparams(
         :param global_scale: optional scale to further scale local quantization scales
         :return: tuple of scale and zero point based on last observed value
         """
-        if observed is not None:
-            group_size = self.quantization_args.group_size
+        strategy = self.quantization_args.strategy
 
-            if self.quantization_args.strategy == QuantizationStrategy.TENSOR:
+        if observed is not None:
+            if strategy == QuantizationStrategy.TENSOR:
                 # re-calculate scale and zero point, update the stored value
                 self._scale, self._zero_point = self.calculate_qparams(observed)
 
-            elif self.quantization_args.strategy in (
+            elif strategy in (
                 QuantizationStrategy.TENSOR_GROUP,
                 QuantizationStrategy.GROUP,
             ):
-                rows = observed.shape[0]
-                columns = observed.shape[1]
-                num_groups = int(ceil(columns / group_size))
-                if num_groups * group_size != columns:
-                    logger.bind(log_once=True).warning(
-                        "Attempting to quantize a module weight whose columns "
-                        f"({columns}) are not divisible by group_size ({group_size}). "
-                        "This scheme is not supported by vLLM, please consider "
-                        "adjusting the group_size for modules with this number of "
-                        "columns",
-                    )
+                # should match the implementation of the first half of
+                # `_process_quantization`
 
-                self._scale = torch.empty(
-                    (rows, num_groups), dtype=observed.dtype, device=observed.device
-                )
+                # get shapes
+                assert observed.ndim >= 2
+                rows, columns = observed.shape[-2:]
+                group_size = self.quantization_args.group_size
+                num_groups = strategy_cdiv(columns, group_size, strategy)
 
+                # FP4: cast zp type
                 if is_fp4(quantization_args=self.quantization_args):
                     zp_dtype = FP8_E4M3_DATA.dtype
                 else:
                     zp_dtype = self.quantization_args.pytorch_dtype()
 
+                # allocate qparams
+                self._scale = torch.empty(
+                    (rows, num_groups), dtype=observed.dtype, device=observed.device
+                )
                 self._zero_point = torch.empty(
                     (rows, num_groups), dtype=zp_dtype, device=observed.device
                 )
 
-                # support column-order (default) quantization as well as other orderings
-                # such as activation ordering. Below checks if g_idx has initialized
-                is_column_order = g_idx is None or -1 in g_idx
-                if is_column_order:
-                    group_sizes = torch.full((num_groups,), group_size, dtype=torch.int)
-                else:
-                    group_indices, group_sizes = torch.unique(g_idx, return_counts=True)
-                    group_sizes = group_sizes[torch.argsort(group_indices)]
-
-                    observed = observed.index_select(-1, g_idx)
+                # permute groups
+                if g_idx is not None:
+                    perm = torch.argsort(g_idx)
+                    observed = observed.index_select(-1, perm)
 
                 # TODO: experiment with vectorizing for loop for performance
+                # reduce over all dims except the second-to-last
                 end = 0
-                for group_index, group_count in enumerate(group_sizes):
+                for group_index in range(num_groups):
                     start = end
-                    end = start + group_count
+                    end = start + group_size
                     scale, zero_point = self.get_qparams_along_dim(
-                        observed[:, start:end],
-                        0,
+                        observed[..., start:end],
+                        dim=-2,
                         tensor_id=group_index,
                         global_scale=global_scale,
                     )
@@ -188,11 +181,11 @@ def get_qparams(
 
                     self._scale[:, group_index] = scale.squeeze(1)
                     self._zero_point[:, group_index] = zero_point.squeeze(1)
 
-        elif self.quantization_args.strategy == QuantizationStrategy.CHANNEL:
-            # assume observed is transposed, because its the output, hence use dim 0
-            self._scale, self._zero_point = self.get_qparams_along_dim(observed, 0)
+        elif strategy == QuantizationStrategy.CHANNEL:
+            # reduce over all dims except the second-to-last
+            self._scale, self._zero_point = self.get_qparams_along_dim(observed, -2)
 
-        elif self.quantization_args.strategy == QuantizationStrategy.TOKEN:
+        elif strategy == QuantizationStrategy.TOKEN:
             # use dim 1, assume the observed.shape = [batch, token, hidden]
             # should be batch, token
             self._scale, self._zero_point = self.get_qparams_along_dim(
@@ -200,10 +193,10 @@ def get_qparams(
                 dim={0, 1},
             )
 
-        elif self.quantization_args.strategy == QuantizationStrategy.BLOCK:
+        elif strategy == QuantizationStrategy.BLOCK:
             # Block-wise quantization: one scale/zero_point per block of shape
             # [block_rows, block_cols]
-            rows, cols = observed.shape[:2]
+            rows, cols = observed.shape[-2:]
             bs = self.quantization_args.block_structure
             if not (
                 isinstance(bs, (list, tuple))
@@ -215,8 +208,8 @@ def get_qparams(
                     f"Must be a list of two ints [rows, cols]."
                 )
             block_rows, block_cols = bs
-            num_br = int(ceil(rows / block_rows))
-            num_bc = int(ceil(cols / block_cols))
+            num_br = strategy_cdiv(rows, block_rows, strategy)
+            num_bc = strategy_cdiv(cols, block_cols, strategy)
 
             # allocate per-block scale and zero_point
             self._scale = torch.empty(
@@ -255,15 +248,20 @@ def get_qparams(
 
     def get_qparams_along_dim(
         self,
-        observed,
+        observed: torch.Tensor,
         dim: Union[int, Iterable[int]],
         tensor_id: Optional[Any] = None,
         global_scale: Optional[Tensor] = None,
    ):
+        # cast to set
        if isinstance(dim, int):
            dim = [dim]
        dim = set(dim)
 
+        # convert negative dims
+        dim = [d if d >= 0 else observed.ndim + d for d in dim]
+
+        # reduce over all dimensions except those passed as `dim`
         reduce_dims = tuple(idx for idx in range(observed.ndim) if idx not in dim)
         return self.calculate_qparams(
             observed,
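The group/tensor_group branch above now derives qparam shapes with strategy_cdiv and reduces each group over every dim except the second-to-last. A minimal sketch of that shape logic (not part of the patch; strategy_cdiv is approximated here by plain ceiling division, whereas the real compressed-tensors helper also applies strategy-dependent divisibility checks):

import math
import torch

observed = torch.randn(4, 1, 12)              # (..., rows, columns)
rows, columns = observed.shape[-2:]           # rows=1, columns=12
group_size = 3
num_groups = math.ceil(columns / group_size)  # 4

scales = []
for g in range(num_groups):
    group = observed[..., g * group_size : (g + 1) * group_size]
    # reduce over every dim except the second-to-last,
    # mirroring get_qparams_along_dim(..., dim=-2)
    reduce_dims = tuple(d for d in range(group.ndim) if d != group.ndim - 2)
    absmax = group.abs().amax(dim=reduce_dims)
    scales.append(absmax / 7.5)               # symmetric int4: q range [-8, 7]
scale = torch.stack(scales, dim=-1)
assert scale.shape == (rows, num_groups)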
diff --git a/tests/llmcompressor/modifiers/calibration/test_observers.py b/tests/llmcompressor/modifiers/calibration/test_observers.py
index fb49ba5da..87065ddc7 100644
--- a/tests/llmcompressor/modifiers/calibration/test_observers.py
+++ b/tests/llmcompressor/modifiers/calibration/test_observers.py
@@ -294,8 +294,46 @@ def test_static_weight_quantization(
             0.06,
         ),
         # channel is not supported, but is in principle equivalent to token/tensor
-        # group is not yet supported
-        # tensor_group is not yet supported
+        (
+            QuantizationArgs(
+                num_bits=4,
+                type="int",
+                symmetric=True,
+                strategy="group",
+                group_size=3,
+                observer="minmax",
+            ),
+            {
+                "default": torch.tensor([[0]]),
+                1: torch.tensor([[3]]),
+            },
+            {
+                "default": torch.tensor([[2]]),
+                1: torch.tensor([[5]]),
+            },
+            torch.tensor([[0.0000, 1.0703, 1.8750, 2.6719, 4.0000, 4.6875]]),
+            0.04,
+        ),
+        (
+            QuantizationArgs(
+                num_bits=4,
+                type="float",  # tensor group requires FP4
+                symmetric=True,
+                strategy="tensor_group",
+                group_size=3,
+                observer="minmax",
+            ),
+            {
+                "default": torch.tensor([[0]]),
+                1: torch.tensor([[3]]),
+            },
+            {
+                "default": torch.tensor([[2]]),
+                1: torch.tensor([[5]]),
+            },
+            torch.tensor([[0.0000, 0.9766, 1.9531, 3.3125, 3.3125, 4.9688]]),
+            0.1,
+        ),
         # block is not supported, but is in principle similar to group
     ],
 )
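The expected values in the new int4 group test case can be reproduced by hand. A minimal sketch (not part of the patch), assuming compressed-tensors' symmetric scale formula scale = absmax / ((q_max - q_min) / 2) with the int4 range [-8, 7]; the small deviations from the expected tensor come from low-precision rounding of the scales in the real pipeline:

import torch

x = torch.tensor([[0.0, 1.0, 2.0, 3.0, 4.0, 5.0]])
group_size, q_min, q_max = 3, -8, 7

out = torch.empty_like(x)
for g in range(x.shape[-1] // group_size):
    group = x[..., g * group_size : (g + 1) * group_size]
    # minmax observer: group 0 sees (0, 2), group 1 sees (3, 5),
    # matching the expected min/max dicts in the test case
    scale = group.abs().amax() / ((q_max - q_min) / 2)         # 2/7.5, then 5/7.5
    q = torch.clamp(torch.round(group / scale), q_min, q_max)  # fake quantize
    out[..., g * group_size : (g + 1) * group_size] = q * scale

# close to the expected [[0.0000, 1.0703, 1.8750, 2.6719, 4.0000, 4.6875]]
print(out)  # tensor([[0.0000, 1.0667, 1.8667, 2.6667, 4.0000, 4.6667]])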