diff --git a/src/llmcompressor/observers/base.py b/src/llmcompressor/observers/base.py
index aa9e1caab..ea325c7dc 100644
--- a/src/llmcompressor/observers/base.py
+++ b/src/llmcompressor/observers/base.py
@@ -10,7 +10,6 @@
 )
 from compressed_tensors.quantization.utils import is_fp4
 from compressed_tensors.registry.registry import RegistryMixin
-from compressed_tensors.utils import safe_permute
 from loguru import logger
 from torch import FloatTensor, IntTensor, Tensor
 
@@ -51,8 +50,12 @@ def forward(
         :return: tuple of scale and zero point based on last observed value
         """
         self.record_observed_tokens(observed)
+
         if should_calculate_gparam:
+            # NOTE: this function updates running min/max values, which leads to
+            # running values updating twice
             return self.get_gparam(observed=observed)
+
         return self.get_qparams(
             observed=observed,
             g_idx=g_idx,
@@ -168,8 +171,7 @@ def get_qparams(
             group_indices, group_sizes = torch.unique(g_idx, return_counts=True)
             group_sizes = group_sizes[torch.argsort(group_indices)]
 
-            perm = torch.argsort(g_idx)
-            observed = safe_permute(observed, perm, dim=1)
+            observed = observed.index_select(-1, g_idx)
 
             # TODO: experiment with vectorizing for loop for performance
             end = 0
diff --git a/src/llmcompressor/observers/min_max.py b/src/llmcompressor/observers/min_max.py
index ce5c0e779..7f94f4788 100644
--- a/src/llmcompressor/observers/min_max.py
+++ b/src/llmcompressor/observers/min_max.py
@@ -1,4 +1,4 @@
-from typing import Any, Optional, Tuple
+from typing import Any, Iterable, Optional, Tuple, Union
 
 import torch
 from compressed_tensors.quantization.quant_args import QuantizationArgs
@@ -58,6 +58,8 @@
 
         # early stopping, save some computation and memory
         if self.averaging_constant == 1.0:
+            self.min_val[tensor_id] = min_val
+            self.max_val[tensor_id] = max_val
             return min_val, max_val
 
         running_min_val = self.min_val.get(tensor_id, None)
@@ -85,7 +87,8 @@ def calculate_gparam(self, observed: torch.Tensor) -> torch.Tensor:
         :param observed: observed tensor to calculate quantization parameters for
         :return: updated global scale derived from the observed tensor
         """
-
+        # NOTE: this function updates running min/max values, which leads to
+        # running values updating twice
         updated_min_val, updated_max_val = self.calculate_updated_min_max(
             observed=observed
         )
@@ -126,14 +129,23 @@ def calculate_qparams(
     def get_qparams_along_dim(
         self,
         observed: torch.Tensor,
-        dim: int,
+        dim: Union[int, Iterable[int]],
         tensor_id: Optional[Any] = None,
         global_scale: Optional[torch.Tensor] = None,
     ):
         """
         Calculate quantization parameters along the specified dimension
         """
-        reduce_dims = tuple(idx for idx in range(observed.ndim) if idx != dim)
+        # cast to set
+        if isinstance(dim, int):
+            dim = [dim]
+        dim = set(dim)
+
+        # convert negative dims
+        dim = [d if d >= 0 else observed.ndim + d for d in dim]
+
+        # reduce all dimensions except the ones passed as argument to this function
+        reduce_dims = tuple(idx for idx in range(observed.ndim) if idx not in dim)
         return self.calculate_qparams(
             observed,
             reduce_dims=reduce_dims,
diff --git a/tests/llmcompressor/modifiers/calibration/test_observers.py b/tests/llmcompressor/modifiers/calibration/test_observers.py
index a742a48b2..fb49ba5da 100644
--- a/tests/llmcompressor/modifiers/calibration/test_observers.py
+++ b/tests/llmcompressor/modifiers/calibration/test_observers.py
@@ -6,7 +6,12 @@
     initialize_module_for_quantization,
 )
 
-from llmcompressor.modifiers.quantization.calibration import initialize_observer
+from llmcompressor.modifiers.quantization.calibration import (
+    calibrate_input_hook,
+    initialize_observer,
+    update_weight_global_scale,
+    update_weight_zp_scale,
+)
 
 
 @pytest.mark.parametrize(
@@ -59,3 +64,277 @@ def test_observers_update(shape, group_size, actorder):
 def assert_alike(a, b):
     assert a.dtype == b.dtype
     assert a.shape == b.shape
+
+
+@pytest.mark.parametrize(
+    "args,exp_min_val,exp_max_val,exp_quant,exp_loss",
+    [
+        (
+            QuantizationArgs(
+                num_bits=4,
+                type="int",
+                symmetric=True,
+                strategy="tensor",  # equivalent to token
+                observer="minmax",
+            ),
+            {"default": torch.tensor(0.0)},
+            {"default": torch.tensor(23.0)},
+            torch.tensor(
+                [
+                    [0.0000, 0.0000, 3.0625, 3.0625, 3.0625, 6.1250],
+                    [6.1250, 6.1250, 9.1875, 9.1875, 9.1875, 12.2500],
+                    [12.2500, 12.2500, 15.3125, 15.3125, 15.3125, 18.3750],
+                    [18.3750, 18.3750, 21.5000, 21.5000, 21.5000, 21.5000],
+                ],
+                dtype=torch.bfloat16,
+            ),
+            0.85,
+        ),
+        (
+            QuantizationArgs(
+                num_bits=4,
+                type="int",
+                symmetric=True,
+                strategy="channel",
+                observer="minmax",
+            ),
+            {"default": torch.tensor([[0], [6], [12], [18]])},
+            {"default": torch.tensor([[5], [11], [17], [23]])},
+            torch.tensor(
+                [
+                    [0.0000, 1.3359, 2.0000, 2.6719, 4.0000, 4.6875],
+                    [5.8750, 7.3438, 7.3438, 8.8125, 10.2500, 10.2500],
+                    [11.3125, 13.6250, 13.6250, 15.8750, 15.8750, 15.8750],
+                    [18.3750, 18.3750, 21.5000, 21.5000, 21.5000, 21.5000],
+                ],
+                dtype=torch.bfloat16,
+            ),
+            0.45,
+        ),
+        (
+            QuantizationArgs(
+                num_bits=4,
+                type="int",
+                symmetric=True,
+                strategy="group",
+                group_size=3,
+                observer="minmax",
+            ),
+            {
+                "default": torch.tensor([[0], [6], [12], [18]]),
+                1: torch.tensor([[3], [9], [15], [21]]),
+            },
+            {
+                "default": torch.tensor([[2], [8], [14], [20]]),
+                1: torch.tensor([[5], [11], [17], [23]]),
+            },
+            torch.tensor(
+                [
+                    [0.0000, 1.0703, 1.8750, 2.6719, 4.0000, 4.6875],
+                    [6.4375, 7.5000, 7.5000, 8.8125, 10.2500, 10.2500],
+                    [11.1875, 13.0625, 13.0625, 15.8750, 15.8750, 15.8750],
+                    [18.7500, 18.7500, 18.7500, 21.5000, 21.5000, 21.5000],
+                ],
+            ),
+            0.45,
+        ),
+        (
+            QuantizationArgs(
+                num_bits=4,
+                type="float",  # tensor group requires FP4
+                symmetric=True,
+                strategy="tensor_group",  # requires float4
+                group_size=3,
+                observer="minmax",
+            ),
+            {
+                "default": torch.tensor([[0], [6], [12], [18]]),
+                1: torch.tensor([[3], [9], [15], [21]]),
+            },
+            {
+                "default": torch.tensor([[2], [8], [14], [20]]),
+                1: torch.tensor([[5], [11], [17], [23]]),
+            },
+            torch.tensor(
+                [
+                    [0.0000, 1.0234, 2.0469, 3.2812, 3.2812, 4.9375],
+                    [5.4688, 8.1875, 8.1875, 10.6875, 10.6875, 10.6875],
+                    [9.8750, 14.7500, 14.7500, 16.3750, 16.3750, 16.3750],
+                    [19.7500, 19.7500, 19.7500, 23.0000, 23.0000, 23.0000],
+                ],
+            ),
+            1.1,
+        ),
+        (
+            QuantizationArgs(
+                num_bits=4,
+                type="int",
+                symmetric=True,
+                strategy="block",
+                block_structure=[2, 3],
+                observer="minmax",
+            ),
+            {
+                "block_0_0": torch.tensor([[0]]),
+                "block_0_1": torch.tensor([[3]]),
+                "block_1_0": torch.tensor([[12]]),
+                "block_1_1": torch.tensor([[15]]),
+            },
+            {
+                "block_0_0": torch.tensor([[8]]),
+                "block_0_1": torch.tensor([[11]]),
+                "block_1_0": torch.tensor([[20]]),
+                "block_1_1": torch.tensor([[23]]),
+            },
+            torch.tensor(
+                [
+                    [0.0000, 1.0703, 2.1406, 2.9375, 4.4062, 4.4062],
+                    [6.4375, 7.5000, 7.5000, 8.8125, 10.2500, 10.2500],
+                    [10.6875, 13.3750, 13.3750, 15.3125, 15.3125, 18.3750],
+                    [18.7500, 18.7500, 18.7500, 21.5000, 21.5000, 21.5000],
+                ],
+            ),
+            0.5,
+        ),
+        (
+            QuantizationArgs(
+                num_bits=4,
+                type="int",
+                symmetric=True,
+                strategy="token",  # equivalent to tensor
+                observer="minmax",
+            ),
+            {"default": torch.tensor(0.0)},
+            {"default": torch.tensor(23.0)},
+            torch.tensor(
+                [
+                    [0.0000, 0.0000, 3.0625, 3.0625, 3.0625, 6.1250],
+                    [6.1250, 6.1250, 9.1875, 9.1875, 9.1875, 12.2500],
+                    [12.2500, 12.2500, 15.3125, 15.3125, 15.3125, 18.3750],
+                    [18.3750, 18.3750, 21.5000, 21.5000, 21.5000, 21.5000],
+                ],
+                dtype=torch.bfloat16,
+            ),
+            0.85,
+        ),
+    ],
+)
+def test_static_weight_quantization(
+    args, exp_min_val, exp_max_val, exp_quant, exp_loss
+):
+    """
+    weight = tensor([[ 0,  1,  2,  3,  4,  5],
+                     [ 6,  7,  8,  9, 10, 11],
+                     [12, 13, 14, 15, 16, 17],
+                     [18, 19, 20, 21, 22, 23]])
+    """
+    # set up weight
+    input_size, output_size = 6, 4
+    linear = torch.nn.Linear(input_size, output_size, bias=False)
+    linear.weight.data = torch.arange(
+        input_size * output_size, dtype=torch.bfloat16
+    ).reshape(output_size, input_size)
+
+    # initialize quantization parameters
+    scheme = QuantizationScheme(targets=[], weights=args)
+    initialize_module_for_quantization(linear, scheme)
+    assert getattr(linear, "quantization_scheme") is scheme
+
+    # calibrate quantization parameters
+    initialize_observer(linear, "weight")
+    update_weight_global_scale(linear)
+    update_weight_zp_scale(linear)
+
+    observer = getattr(linear, "weight_observer")
+    assert (
+        observer.min_val.keys()
+        == observer.max_val.keys()
+        == exp_min_val.keys()
+        == exp_max_val.keys()
+    )
+    for key in observer.min_val.keys():
+        assert torch.equal(observer.min_val[key], exp_min_val[key])
+        assert torch.equal(observer.max_val[key], exp_max_val[key])
+
+    # forward pass
+    input = torch.eye(input_size, dtype=torch.bfloat16)
+    output = linear(input)
+
+    assert torch.allclose(output.T, exp_quant.to(output.dtype))
+    assert torch.nn.functional.mse_loss(output.T, linear.weight) <= exp_loss
+
+
+@pytest.mark.parametrize(
+    "args,exp_min_val,exp_max_val,exp_quant,exp_loss",
+    [
+        (
+            QuantizationArgs(
+                num_bits=4,
+                type="int",
+                symmetric=True,
+                strategy="tensor",  # equivalent to token
+                observer="minmax",
+            ),
+            {"default": torch.tensor(0.0)},
+            {"default": torch.tensor(5.0)},
+            torch.tensor([[0.0000, 1.3359, 2.0000, 2.6719, 4.0000, 4.6875]]),
+            0.06,
+        ),
+        (
+            QuantizationArgs(
+                num_bits=4,
+                type="int",
+                symmetric=True,
+                strategy="token",  # equivalent to tensor
+                observer="minmax",
+            ),
+            {"default": torch.tensor(0.0)},
+            {"default": torch.tensor(5.0)},
+            torch.tensor([[0.0000, 1.3359, 2.0000, 2.6719, 4.0000, 4.6875]]),
+            0.06,
+        ),
+        # channel is not supported, but is in principle equivalent to token/tensor
+        # group is not yet supported
+        # tensor_group is not yet supported
+        # block is not supported, but is in principle similar to group
+    ],
+)
+def test_static_activation_quantization(
+    args, exp_min_val, exp_max_val, exp_quant, exp_loss
+):
+    """
+    input = tensor([[ 0,  1,  2,  3,  4,  5]])
+    """
+    # set up activation (and identity weight)
+    input_size = 6
+    input = torch.arange(input_size, dtype=torch.bfloat16).unsqueeze(0)
+    linear = torch.nn.Linear(input_size, input_size, bias=False)
+    linear.weight.data = torch.eye(input_size, dtype=torch.bfloat16)
+
+    # initialize quantization parameters
+    scheme = QuantizationScheme(targets=[], input_activations=args)
+    initialize_module_for_quantization(linear, scheme)
+    assert getattr(linear, "quantization_scheme") is scheme
+
+    # calibrate quantization parameters
+    initialize_observer(linear, "input")
+    linear.register_forward_pre_hook(calibrate_input_hook)
+
+    # calibration forward pass
+    output = linear(input)
+
+    # check calibration
+    observer = getattr(linear, "input_observer")
+    assert (
+        observer.min_val.keys()
+        == observer.max_val.keys()
+        == exp_min_val.keys()
+        == exp_max_val.keys()
+    )
+    for key in observer.min_val.keys():
+        assert torch.equal(observer.min_val[key], exp_min_val[key])
+        assert torch.equal(observer.max_val[key], exp_max_val[key])
+
+    # check forward pass
+    assert torch.allclose(output, exp_quant.to(output.dtype))
+    assert torch.nn.functional.mse_loss(output, input) <= exp_loss
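
For reviewers, a small standalone sketch of the gather that replaces the removed `safe_permute` helper in `get_qparams`: `index_select` reorders the columns of `observed` along the last dim according to `g_idx`. The `g_idx` values below are illustrative only, not the library's actual group-index contents.

```python
import torch

# Gather columns along the last dim in the order given by g_idx
# (illustrative permutation; real g_idx comes from activation ordering).
observed = torch.arange(12.0).reshape(3, 4)
g_idx = torch.tensor([2, 0, 3, 1])
print(observed.index_select(-1, g_idx))
# tensor([[ 2.,  0.,  3.,  1.],
#         [ 6.,  4.,  7.,  5.],
#         [10.,  8., 11.,  9.]])
```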
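The NOTE added in two places warns that `calculate_gparam` folds the observed values into the running min/max, after which the qparam path folds them in again. A toy sketch of why that matters when `averaging_constant < 1` (the `RunningMinMax` class here is a hypothetical stand-in, not the library API): the twice-updated observer weights the most recent batch more heavily.

```python
import torch

# Hypothetical running min/max tracker mirroring the observer's
# moving-average update; not the library API.
class RunningMinMax:
    def __init__(self, averaging_constant: float):
        self.c = averaging_constant
        self.min_val, self.max_val = {}, {}

    def update(self, observed: torch.Tensor, tensor_id: str = "default"):
        min_val, max_val = observed.amin(), observed.amax()
        if tensor_id in self.min_val:
            # moving average toward the newly observed values
            min_val = self.min_val[tensor_id] + self.c * (min_val - self.min_val[tensor_id])
            max_val = self.max_val[tensor_id] + self.c * (max_val - self.max_val[tensor_id])
        self.min_val[tensor_id] = min_val
        self.max_val[tensor_id] = max_val

once, twice = RunningMinMax(0.5), RunningMinMax(0.5)
a, b = torch.tensor([0.0, 1.0]), torch.tensor([0.0, 3.0])
once.update(a)                    # one update per batch
twice.update(a); twice.update(a)  # gparam path, then qparam path
once.update(b)
twice.update(b); twice.update(b)
print(once.max_val["default"])   # tensor(2.0): 1 + 0.5 * (3 - 1)
print(twice.max_val["default"])  # tensor(2.5): the latest batch counted twice
```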
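Finally, a runnable sketch of the dimension handling `get_qparams_along_dim` now implements: `dim` may be an int or an iterable of ints, negative dims are normalized against `observed.ndim`, and every remaining axis becomes a reduction axis. `reduce_dims_for` is a hypothetical helper name used only for this illustration.

```python
import torch

# Keep the requested dims, reduce over everything else,
# mirroring the new logic in get_qparams_along_dim.
def reduce_dims_for(observed: torch.Tensor, dim) -> tuple:
    if isinstance(dim, int):
        dim = [dim]
    dim = {d if d >= 0 else observed.ndim + d for d in dim}  # normalize negatives
    return tuple(idx for idx in range(observed.ndim) if idx not in dim)

x = torch.randn(4, 6, 8)
assert reduce_dims_for(x, 0) == (1, 2)      # keep dim 0, reduce dims 1 and 2
assert reduce_dims_for(x, (0, -1)) == (1,)  # -1 normalizes to dim 2
# e.g. the per-kept-dim max a min/max observer would compute
print(x.amax(dim=reduce_dims_for(x, (0, -1)), keepdim=True).shape)  # (4, 1, 8)
```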