diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index 6e533cc1a..6dae860ea 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -1,4 +1,5 @@ import inspect +import warnings from typing import Dict, List, Optional, Tuple, Union import torch @@ -183,25 +184,20 @@ def validate_model_after(model: "AWQModifier") -> "AWQModifier": model._group_size = next(iter(group_size_set)) - in_num_bits_set = set( - group.input_activations.num_bits + num_bits_set = { + act.num_bits for group in config.config_groups.values() - if group.input_activations is not None - ) - assert len(in_num_bits_set) == 0 or in_num_bits_set == {16}, ( - "AWQ activations must be 16-bit precision, " - f"input activations {in_num_bits_set} not allowed" - ) - - out_num_bits_set = set( - group.output_activations.num_bits - for group in config.config_groups.values() - if group.output_activations is not None - ) - assert len(out_num_bits_set) == 0 or out_num_bits_set == {16}, ( - "AWQ activations must be 16-bit precision, " - f"output activations {out_num_bits_set} not allowed" - ) + for act in (group.input_activations, group.output_activations) + if act is not None + } + if not (len(num_bits_set) == 0 or num_bits_set == {16}): + warnings.warn( + "A strategy including activation quantization was detected. " + "AWQ was originally intended for weight-only quantization. " + "Lower-precision activations are an experimental feature, and " + "overall performance may be poor. If it is, consider using " + "`W4A16` or `W4A16_ASYM` quantization schemes instead." + ) return model diff --git a/tests/llmcompressor/modifiers/awq/test_base.py b/tests/llmcompressor/modifiers/awq/test_base.py index a4adfbdac..fe983c882 100644 --- a/tests/llmcompressor/modifiers/awq/test_base.py +++ b/tests/llmcompressor/modifiers/awq/test_base.py @@ -117,9 +117,6 @@ def test_set_resolved_mappings(): @pytest.mark.unit def test_validate(): - with pytest.raises(ValidationError): - AWQModifier(scheme="W8A8") - with pytest.raises(ValidationError): AWQModifier( config_groups={