Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 19 additions & 7 deletions src/llmcompressor/pipelines/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@
from torch.utils.data.dataloader import DataLoader

from llmcompressor.modifiers import Modifier
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.modifiers.awq import AWQModifier
from llmcompressor.modifiers.pruning.sparsegpt.sgpt_base import SparsityModifierBase
from llmcompressor.modifiers.quantization import GPTQModifier, QuantizationModifier
from llmcompressor.modifiers.smoothquant import SmoothQuantModifier

if TYPE_CHECKING:
from llmcompressor.args.dataset_arguments import DatasetArguments
Expand Down Expand Up @@ -55,10 +58,19 @@ def from_modifiers(

@staticmethod
def _infer_pipeline(modifiers: list[Modifier]) -> str:
    """Infer which calibration pipeline a list of modifiers requires.

    Returns "sequential" if any modifier needs calibration data, otherwise
    "datafree". Data-driven algorithms (SmoothQuant, SparseGPT/Wanda via
    ``SparsityModifierBase``, GPTQ, AWQ) always require calibration; a plain
    ``QuantizationModifier`` only does when its resolved quantization config
    says so (e.g. activation quantization, as opposed to weight-only).

    :param modifiers: modifiers belonging to the recipe being applied
    :return: the name of the pipeline to use ("sequential" or "datafree")
    """

    def _modifier_requires_calibration(modifier: Modifier) -> bool:
        # NOTE: check the always-calibrating types first so that subclasses
        # of QuantizationModifier (e.g. GPTQModifier) are not short-circuited
        # by the weight-only check below.
        if isinstance(
            modifier,
            (SmoothQuantModifier, SparsityModifierBase, GPTQModifier, AWQModifier),
        ):
            return True
        elif isinstance(modifier, QuantizationModifier):
            # Weight-only schemes can be applied without data; ask the
            # resolved config whether calibration is actually needed.
            config = modifier.resolve_quantization_config()
            return config.requires_calibration_data()
        else:
            # Transform-style modifiers (e.g. QuIP, SpinQuant) and anything
            # else unrecognized are assumed to be data-free.
            return False

    if any(_modifier_requires_calibration(modifier) for modifier in modifiers):
        return "sequential"
    else:
        return "datafree"
35 changes: 35 additions & 0 deletions tests/llmcompressor/pipelines/test_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import pytest

from llmcompressor.modifiers.awq import AWQModifier
from llmcompressor.modifiers.pruning import SparseGPTModifier, WandaPruningModifier
from llmcompressor.modifiers.quantization import GPTQModifier, QuantizationModifier
from llmcompressor.modifiers.smoothquant import SmoothQuantModifier
from llmcompressor.modifiers.transform import QuIPModifier, SpinQuantModifier
from llmcompressor.pipelines import (
CalibrationPipeline,
DataFreePipeline,
SequentialPipeline,
)


@pytest.mark.parametrize(
    "modifiers,exp_pipeline",
    [
        ([QuantizationModifier(scheme="FP8")], SequentialPipeline),
        ([QuantizationModifier(scheme="W4A16")], DataFreePipeline),
        ([GPTQModifier(scheme="FP8")], SequentialPipeline),
        ([GPTQModifier(scheme="W4A16")], SequentialPipeline),
        ([SmoothQuantModifier(), GPTQModifier(scheme="W4A16")], SequentialPipeline),
        ([AWQModifier(scheme="W4A16")], SequentialPipeline),
        ([AWQModifier(scheme="FP8")], SequentialPipeline),
        ([SparseGPTModifier(sparsity=1.0)], SequentialPipeline),
        ([WandaPruningModifier(sparsity=1.0)], SequentialPipeline),
        ([QuIPModifier()], DataFreePipeline),
        ([SpinQuantModifier()], DataFreePipeline),
        ([QuIPModifier(), QuantizationModifier(scheme="FP8")], SequentialPipeline),
        ([QuIPModifier(), QuantizationModifier(scheme="W4A16")], DataFreePipeline),
    ],
)
def test_infer_pipeline(modifiers, exp_pipeline):
    """Each modifier combination should resolve to the expected pipeline class."""
    inferred = CalibrationPipeline.from_modifiers(modifiers)
    assert isinstance(inferred, exp_pipeline)
Loading