diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py
index 1d594d46a..31e11315b 100644
--- a/src/llmcompressor/modifiers/awq/base.py
+++ b/src/llmcompressor/modifiers/awq/base.py
@@ -3,7 +3,7 @@
 from typing import Dict, List, Optional, Tuple, Union
 
 import torch
-from compressed_tensors.quantization import disable_quantization
+from compressed_tensors.quantization import QuantizationType, disable_quantization
 from compressed_tensors.utils import (
     align_modules,
     get_execution_device,
@@ -126,6 +126,7 @@ class AWQModifier(Modifier, QuantizationMixin):
 
     # Private vars set during validation
     _num_bits: Optional[int] = PrivateAttr(default=None)
+    _activation_bits: int = PrivateAttr(default=16)
     _symmetric: Optional[bool] = PrivateAttr(default=None)
     _group_size: Optional[int] = PrivateAttr(default=None)
 
@@ -189,6 +190,18 @@ def validate_model_after(model: "AWQModifier") -> "AWQModifier":
             if act is not None
         }
         if not (len(num_bits_set) == 0 or num_bits_set == {16}):
+            num_bits_type = {
+                act.type
+                for group in config.config_groups.values()
+                for act in (group.input_activations, group.output_activations)
+                if act is not None
+            }
+            assert (
+                next(iter(num_bits_type)) == QuantizationType.FLOAT
+            ), "In AWQ, lower-precision activation quantization must be float"
+
+            model._activation_bits = next(iter(num_bits_set))
+
             warnings.warn(
                 "A strategy including activation quantization was detected. "
                 "AWQ was originally intended for weight-only quantization. "
@@ -612,16 +625,26 @@ def _compute_best_scale(
             # Q(W * s)
             for linear in linears2scale:
                 linear.weight.mul_(_scalesview)
-                update_offload_parameter(
-                    linear,
-                    "weight",
+                scaled_weight = (
                     _pseudo_quantize_tensor(
                         w=linear.weight.data,
                         symmetric=self._symmetric,
                         bit_width=self._num_bits,
                         group_size=self._group_size,
                     )[0]
-                    / _scalesview,
+                    / _scalesview
+                )
+
+                # fp8 activation simulation
+                if self._activation_bits == 8:
+                    scaled_weight = scaled_weight.to(torch.float8_e4m3fn).to(
+                        torch.float16
+                    )
+
+                update_offload_parameter(
+                    linear,
+                    "weight",
+                    scaled_weight,
                 )
 
             # W * X
diff --git a/tests/llmcompressor/modifiers/awq/test_base.py b/tests/llmcompressor/modifiers/awq/test_base.py
index 8119ca90c..a7b644a19 100644
--- a/tests/llmcompressor/modifiers/awq/test_base.py
+++ b/tests/llmcompressor/modifiers/awq/test_base.py
@@ -1,6 +1,10 @@
 import pytest
 import torch
-from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme
+from compressed_tensors.quantization import (
+    QuantizationArgs,
+    QuantizationScheme,
+    QuantizationType,
+)
 from pydantic import ValidationError
 
 from llmcompressor.modifiers.awq import AWQMapping, AWQModifier
@@ -154,6 +158,25 @@ def test_validate():
             }
         )
 
+    with pytest.raises(ValidationError):
+        AWQModifier(
+            config_groups={
+                "group_0": QuantizationScheme(
+                    targets=["Linear"],
+                    weights=QuantizationArgs(
+                        num_bits=4,
+                        group_size=128,
+                    ),
+                    input_activations=QuantizationArgs(
+                        num_bits=8, type=QuantizationType.INT
+                    ),
+                    output_activations=QuantizationArgs(
+                        num_bits=8, type=QuantizationType.INT
+                    ),
+                ),
+            }
+        )
+
     # valid configuration
     AWQModifier(
         config_groups={
@@ -165,6 +188,16 @@ def test_validate():
                 targets=["Linear"],
                 weights=QuantizationArgs(num_bits=4, group_size=128, symmetric=False),
             ),
+            "group_2": QuantizationScheme(
+                targets=["Linear"],
+                weights=QuantizationArgs(num_bits=4, group_size=128, symmetric=False),
+                input_activations=QuantizationArgs(
+                    num_bits=8, type=QuantizationType.FLOAT
+                ),
+                output_activations=QuantizationArgs(
+                    num_bits=8, type=QuantizationType.FLOAT
+                ),
+            ),
         }
     )
 
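For context, the valid "group_2" case above is the kind of W4A8 recipe this change is meant to accept: 4-bit weights plus 8-bit float activations, which the modifier now simulates during scale search by round-tripping the quantized weights through torch.float8_e4m3fn. Below is a minimal sketch of constructing such a modifier, mirroring the test's valid configuration; any surrounding model loading or oneshot call is omitted and would be user-supplied.

from compressed_tensors.quantization import (
    QuantizationArgs,
    QuantizationScheme,
    QuantizationType,
)

from llmcompressor.modifiers.awq import AWQModifier

# W4A8 recipe sketch: 4-bit asymmetric grouped weights with 8-bit float (fp8)
# activations. Int-typed 8-bit activations are rejected by the new validator.
modifier = AWQModifier(
    config_groups={
        "group_0": QuantizationScheme(
            targets=["Linear"],
            weights=QuantizationArgs(num_bits=4, group_size=128, symmetric=False),
            input_activations=QuantizationArgs(
                num_bits=8, type=QuantizationType.FLOAT
            ),
            output_activations=QuantizationArgs(
                num_bits=8, type=QuantizationType.FLOAT
            ),
        ),
    }
)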