diff --git a/src/llmcompressor/transformers/compression/quantization_format.py b/src/llmcompressor/transformers/compression/quantization_format.py
index 06edf3c59..67cf9bab3 100644
--- a/src/llmcompressor/transformers/compression/quantization_format.py
+++ b/src/llmcompressor/transformers/compression/quantization_format.py
@@ -1,22 +1,70 @@
-from typing import Optional
+from typing import List, Optional
 
 from compressed_tensors import CompressionFormat
 from compressed_tensors.config import SparsityStructure
-from compressed_tensors.quantization import QuantizationStrategy, QuantizationType
+from compressed_tensors.quantization import (
+    QuantizationArgs,
+    QuantizationStrategy,
+    QuantizationType,
+)
 from compressed_tensors.quantization.utils import is_module_quantized
 
-__all__ = ["infer_quantization_format"]
+__all__ = ["infer_per_module_quantization_format"]
 
 
-def infer_quantization_format(
+def _get_quant_method(
+    input_args: QuantizationArgs,
+    weight_args: QuantizationArgs,
+    sparsity_structure: Optional[str] = None,
+):
+    is_24_structure = (
+        SparsityStructure(sparsity_structure) == SparsityStructure.TWO_FOUR
+    )
+    is_weight_only = weight_args is not None and input_args is None
+
+    if weight_args.num_bits == 4 and weight_args.type == QuantizationType.FLOAT.value:
+        return CompressionFormat.nvfp4_pack_quantized
+
+    if is_weight_only:  # w4a16 and w8a16
+        is_valid_pack = (
+            weight_args.num_bits in [4, 8]
+            and weight_args.type == QuantizationType.INT.value
+        )
+        if not is_valid_pack:  # packing only valid for int4 and int8
+            return CompressionFormat.naive_quantized
+        if is_24_structure:
+            if (
+                weight_args.strategy is not QuantizationStrategy.CHANNEL.value
+                and weight_args.strategy is not QuantizationStrategy.GROUP.value
+            ):
+                # marlin24 kernel only applicable for channel/group quantization
+                return CompressionFormat.pack_quantized
+            return CompressionFormat.marlin_24
+        return CompressionFormat.pack_quantized
+
+    else:  # w8a8 float and int
+        if (
+            weight_args.type == QuantizationType.FLOAT.value
+            and weight_args.num_bits == 8
+        ):
+            return CompressionFormat.float_quantized
+        if weight_args.type == QuantizationType.INT.value:
+            return CompressionFormat.int_quantized
+
+    return CompressionFormat.naive_quantized
+
+
+def infer_per_module_quantization_format(
     model,
     quantization_format: Optional[str] = None,
    save_compressed: bool = False,
     sparsity_structure: Optional[str] = None,
-) -> str:
+) -> Optional[List[str]]:
     """
     Infers the quantization format for a model based on its state and provided
-    compression arguments.
+    compression arguments. Also updates the quantization_scheme.format value
+    based on the inferred format. Returns the unique list of formats in the model
+    or None if the list is empty.
 
     The following table outlines the possible quantization and sparsity formats
     along with their corresponding compressor formats:
@@ -27,6 +75,8 @@ def infer_quantization_format(
     +---------------+----------+----------------------+---------------------+
     | W8A8 - int    | None     | int_quantized        | Dense               |
     | W8A8 - float  | None     | float_quantized      | Dense               |
+    | W4A16 - float | None     | nvfp4_pack_quantized | Dense               |
+    | W4A4 - float  | None     | nvfp4_pack_quantized | Dense               |
     | W4A16 - int   | None     | pack_quantized       | Dense               |
     | W8A16 - int   | None     | pack_quantized       | Dense               |
     | W8A16 - float | None     | naive_quantized      | Dense               |
@@ -44,74 +94,26 @@ def infer_quantization_format(
     :param save_compressed: used to infer a quantization format if None is provided
     :return compression format appropriate for model
     """
-    if quantization_format is not None:
-        return quantization_format
-
-    weight_args, input_args = _get_unique_quant_args(model)
-    if len(weight_args) <= 0:
-        return None
-
-    if save_compressed:
-        is_24_structure = (
-            SparsityStructure(sparsity_structure) == SparsityStructure.TWO_FOUR
-        )
-        is_weight_only = len(input_args) == 0 and len(weight_args) > 0
-
-        if (
-            weight_args[0].num_bits == 4
-            and weight_args[0].type == QuantizationType.FLOAT.value
-        ):
-            return CompressionFormat.nvfp4_pack_quantized
-        if is_weight_only:  # w4a16 and w8a16
-            is_valid_pack = all(
-                weight_arg.num_bits in [4, 8]
-                and weight_arg.type == QuantizationType.INT.value
-                for weight_arg in weight_args
-            )
-            if not is_valid_pack:  # packing only valid for int4 and int 8
-                return CompressionFormat.naive_quantized
-            if is_24_structure:
-                for arg in weight_args:
-                    if (
-                        arg.strategy is not QuantizationStrategy.CHANNEL.value
-                        and arg.strategy is not QuantizationStrategy.GROUP.value
-                    ):
-                        # marlin24 kernel only applicable for channel/group quantization
-                        return CompressionFormat.pack_quantized
-                return CompressionFormat.marlin_24
-            return CompressionFormat.pack_quantized
-        else:  # w8a8 float and int
-            if len(weight_args) == 1:
-                if (
-                    weight_args[0].type == QuantizationType.FLOAT.value
-                    and weight_args[0].num_bits == 8
-                ):
-                    return CompressionFormat.float_quantized
-                if weight_args[0].type == QuantizationType.INT.value:
-                    return CompressionFormat.int_quantized
-
-            return CompressionFormat.naive_quantized
-    else:
-        # format will be inferred from config
+    if not save_compressed:
         return None
 
+    if quantization_format:
+        return [quantization_format]
 
-def _get_unique_quant_args(model):
-    """
-    Gets a list of all the unique quantization settings present in model
-    """
-    quant_info_weight = []
-    quant_info_inputs = []
+    unique_formats = []
     for submodule in model.modules():
         if is_module_quantized(submodule):
             weight_scheme = submodule.quantization_scheme.weights
             input_scheme = submodule.quantization_scheme.input_activations
-            if weight_scheme is not None:
-                if weight_scheme not in quant_info_weight:
-                    quant_info_weight.append(weight_scheme)
-            if input_scheme is not None:
-                if input_scheme not in quant_info_inputs:
-                    quant_info_inputs.append(input_scheme)
-
-    return quant_info_weight, quant_info_inputs
+            if weight_scheme is None:
+                continue  # no weight quant - nothing to compress
+            compression_format = _get_quant_method(
+                input_scheme, weight_scheme, sparsity_structure
+            )
+            submodule.quantization_scheme.format = compression_format.value
+            if compression_format not in unique_formats:
+                unique_formats.append(compression_format)
+
+    if len(unique_formats) > 0:
+        return unique_formats
+    return None
diff --git a/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py b/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py
index da2ab4230..756c02fe4 100644
--- a/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py
+++ b/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py
@@ -19,7 +19,7 @@ from llmcompressor.core import active_session
 from llmcompressor.pytorch.model_load.helpers import copy_python_files_from_model_cache
 from llmcompressor.transformers.compression.quantization_format import (
-    infer_quantization_format,
+    infer_per_module_quantization_format,
 )
 from llmcompressor.transformers.compression.sparsity_metadata_config import (
     SparsityConfigMetadata,
 )
@@ -228,13 +228,15 @@ def get_model_compressor(
             SparsityConfigMetadata.infer_sparsity_structure(model)
         )
 
-    quantization_format: Optional[CompressionFormat] = infer_quantization_format(
-        model=model,
-        quantization_format=quantization_format,
-        save_compressed=save_compressed,
-        sparsity_structure=None
-        if sparsity_config is None
-        else sparsity_config.sparsity_structure,
+    quantization_format: Optional[CompressionFormat] = (
+        infer_per_module_quantization_format(
+            model=model,
+            quantization_format=quantization_format,
+            save_compressed=save_compressed,
+            sparsity_structure=None
+            if sparsity_config is None
+            else sparsity_config.sparsity_structure,
+        )
     )
 
     return ModelCompressor.from_pretrained_model(
diff --git a/tests/llmcompressor/transformers/compression/test_infer_quant_format.py b/tests/llmcompressor/transformers/compression/test_infer_quant_format.py
index 1eb3bf202..c671ac2e0 100644
--- a/tests/llmcompressor/transformers/compression/test_infer_quant_format.py
+++ b/tests/llmcompressor/transformers/compression/test_infer_quant_format.py
@@ -2,7 +2,7 @@
 from compressed_tensors.quantization import preset_name_to_scheme
 
 from llmcompressor.transformers.compression.quantization_format import (
-    infer_quantization_format,
+    infer_per_module_quantization_format,
 )
 from tests.llmcompressor.pytorch.helpers import LinearNet
 
@@ -25,7 +25,7 @@ def test_infer_quant_format(preset, sparsity_structure, expected_format):
     for _, module in dummy_model.named_modules():
         module.quantization_scheme = quant_scheme
 
-    inferred_format = infer_quantization_format(
+    inferred_format = infer_per_module_quantization_format(
         dummy_model, save_compressed=True, sparsity_structure=sparsity_structure
     )
-    assert inferred_format.value == expected_format
+    assert inferred_format[0].value == expected_format
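Rough usage sketch of the renamed entry point, mirroring the updated test: attach a preset scheme to a small module tree, then check both the returned format list and the per-module `quantization_scheme.format` field the new code writes back. The `torch.nn.Sequential` stand-in, the `"W4A16"` preset name, and the commented expected values are illustrative assumptions, not part of this diff.

```python
import torch
from compressed_tensors.quantization import preset_name_to_scheme

from llmcompressor.transformers.compression.quantization_format import (
    infer_per_module_quantization_format,
)

# Stand-in for a model; the test suite uses LinearNet, but any module tree
# containing Linear layers exercises the same path.
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8))

# Attach a W4A16 scheme to every Linear, as test_infer_quant_format does.
scheme = preset_name_to_scheme("W4A16", ["Linear"])
for module in model.modules():
    if isinstance(module, torch.nn.Linear):
        module.quantization_scheme = scheme

formats = infer_per_module_quantization_format(model, save_compressed=True)
print(formats)  # expected: [CompressionFormat.pack_quantized]

# New side effect: each quantized module's scheme records its chosen format.
print(model[0].quantization_scheme.format)  # expected: "pack-quantized"
```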
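The list return type (and the `[0]` index in the updated test assertion) exists because different submodules can now resolve to different compressors. A hypothetical mixed-precision continuation of the sketch above; the `"FP8"` preset name is an assumption, and the ordering of the result depends on module iteration order:

```python
# Hypothetical mixed case: first Linear quantized W8A8-float, second W4A16-int.
model[0].quantization_scheme = preset_name_to_scheme("FP8", ["Linear"])
model[1].quantization_scheme = preset_name_to_scheme("W4A16", ["Linear"])

formats = infer_per_module_quantization_format(model, save_compressed=True)
# Two unique formats come back, roughly:
# [CompressionFormat.float_quantized, CompressionFormat.pack_quantized]
```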