diff --git a/src/llmcompressor/pipelines/registry.py b/src/llmcompressor/pipelines/registry.py
index 60cf37edae..e65bb5b259 100644
--- a/src/llmcompressor/pipelines/registry.py
+++ b/src/llmcompressor/pipelines/registry.py
@@ -7,7 +7,10 @@
 from torch.utils.data.dataloader import DataLoader
 
 from llmcompressor.modifiers import Modifier
-from llmcompressor.modifiers.quantization import QuantizationModifier
+from llmcompressor.modifiers.awq import AWQModifier
+from llmcompressor.modifiers.pruning.sparsegpt.sgpt_base import SparsityModifierBase
+from llmcompressor.modifiers.quantization import GPTQModifier, QuantizationModifier
+from llmcompressor.modifiers.smoothquant import SmoothQuantModifier
 
 if TYPE_CHECKING:
     from llmcompressor.args.dataset_arguments import DatasetArguments
@@ -55,10 +58,23 @@ def from_modifiers(
 
     @staticmethod
     def _infer_pipeline(modifiers: list[Modifier]) -> str:
-        # only in the case of weight-only qmod quantization can we skip calibration
-        if len(modifiers) == 1 and isinstance(modifiers[0], QuantizationModifier):
-            config = modifiers[0].resolve_quantization_config()
-            if not config.requires_calibration_data():
-                return "datafree"
+        def _modifier_requires_calibration(modifier: Modifier) -> bool:
+            # these modifier types are always treated as needing calibration data
+            if isinstance(
+                modifier,
+                (SmoothQuantModifier, SparsityModifierBase, GPTQModifier, AWQModifier),
+            ):
+                return True
+            elif isinstance(modifier, QuantizationModifier):
+                # plain quantization defers to its resolved config
+                config = modifier.resolve_quantization_config()
+                return config.requires_calibration_data()
+            else:
+                # every other modifier type is treated as data free
+                return False
 
-        return "sequential"
+        # any calibration-dependent modifier forces the sequential pipeline
+        if any(_modifier_requires_calibration(modifier) for modifier in modifiers):
+            return "sequential"
+        else:
+            return "datafree"
diff --git a/tests/llmcompressor/pipelines/test_registry.py b/tests/llmcompressor/pipelines/test_registry.py
new file mode 100644
index 0000000000..39c9ddc65a
--- /dev/null
+++ b/tests/llmcompressor/pipelines/test_registry.py
@@ -0,0 +1,35 @@
+import pytest
+
+from llmcompressor.modifiers.awq import AWQModifier
+from llmcompressor.modifiers.pruning import SparseGPTModifier, WandaPruningModifier
+from llmcompressor.modifiers.quantization import GPTQModifier, QuantizationModifier
+from llmcompressor.modifiers.smoothquant import SmoothQuantModifier
+from llmcompressor.modifiers.transform import QuIPModifier, SpinQuantModifier
+from llmcompressor.pipelines import (
+    CalibrationPipeline,
+    DataFreePipeline,
+    SequentialPipeline,
+)
+
+
+@pytest.mark.parametrize(
+    "modifiers,exp_pipeline",
+    [
+        ([QuantizationModifier(scheme="FP8")], SequentialPipeline),
+        ([QuantizationModifier(scheme="W4A16")], DataFreePipeline),
+        ([GPTQModifier(scheme="FP8")], SequentialPipeline),
+        ([GPTQModifier(scheme="W4A16")], SequentialPipeline),
+        ([SmoothQuantModifier(), GPTQModifier(scheme="W4A16")], SequentialPipeline),
+        ([AWQModifier(scheme="W4A16")], SequentialPipeline),
+        ([AWQModifier(scheme="FP8")], SequentialPipeline),
+        ([SparseGPTModifier(sparsity=1.0)], SequentialPipeline),
+        ([WandaPruningModifier(sparsity=1.0)], SequentialPipeline),
+        ([QuIPModifier()], DataFreePipeline),
+        ([SpinQuantModifier()], DataFreePipeline),
+        ([QuIPModifier(), QuantizationModifier(scheme="FP8")], SequentialPipeline),
+        ([QuIPModifier(), QuantizationModifier(scheme="W4A16")], DataFreePipeline),
+    ],
+)
+def test_infer_pipeline(modifiers, exp_pipeline):
+    pipeline = CalibrationPipeline.from_modifiers(modifiers)
+    assert isinstance(pipeline, exp_pipeline)