Skip to content

Commit 16d3d9f

Browse files
committed
finish
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 8466a68 commit 16d3d9f

File tree

2 files changed

+35
-24
lines changed

2 files changed

+35
-24
lines changed

src/llmcompressor/pipelines/registry.py

Lines changed: 8 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -7,8 +7,9 @@
77
from torch.utils.data.dataloader import DataLoader
88

99
from llmcompressor.modifiers import Modifier
10+
from llmcompressor.modifiers.awq import AWQModifier
1011
from llmcompressor.modifiers.pruning.sparsegpt.sgpt_base import SparsityModifierBase
11-
from llmcompressor.modifiers.quantization import QuantizationModifier
12+
from llmcompressor.modifiers.quantization import GPTQModifier, QuantizationModifier
1213
from llmcompressor.modifiers.smoothquant import SmoothQuantModifier
1314

1415
if TYPE_CHECKING:
@@ -58,11 +59,14 @@ def from_modifiers(
5859
@staticmethod
5960
def _infer_pipeline(modifiers: list[Modifier]) -> str:
6061
def _modifier_requires_calibration(modifier: Modifier):
61-
if isinstance(modifier, QuantizationModifier):
62+
if isinstance(
63+
modifier,
64+
(SmoothQuantModifier, SparsityModifierBase, GPTQModifier, AWQModifier),
65+
):
66+
return True
67+
elif isinstance(modifier, QuantizationModifier):
6268
config = modifier.resolve_quantization_config()
6369
return config.requires_calibration_data()
64-
elif isinstance(modifier, (SmoothQuantModifier, SparsityModifierBase)):
65-
return True
6670
else:
6771
return False
6872

Lines changed: 27 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -1,28 +1,35 @@
11
import pytest
22

3-
from llmcompressor.modifiers.quantization import QuantizationModifier, GPTQModifier
4-
from llmcompressor.modifiers.smoothquant import SmoothQuantModifier
53
from llmcompressor.modifiers.awq import AWQModifier
6-
from llmcompressor.modifiers.pruning import SparseGPTQModifier, WandaPruningModifier
4+
from llmcompressor.modifiers.pruning import SparseGPTModifier, WandaPruningModifier
5+
from llmcompressor.modifiers.quantization import GPTQModifier, QuantizationModifier
6+
from llmcompressor.modifiers.smoothquant import SmoothQuantModifier
77
from llmcompressor.modifiers.transform import QuIPModifier, SpinQuantModifier
8-
from llmcompressor.pipelines import CalibrationPipeline, SequentialPipeline, DataFreePipeline
8+
from llmcompressor.pipelines import (
9+
CalibrationPipeline,
10+
DataFreePipeline,
11+
SequentialPipeline,
12+
)
913

1014

11-
@pytest.mark.parametrize("modifiers", [
12-
([QuantizationModifier(scheme="FP8")], SequentialPipeline)
13-
([QuantizationModifier(scheme="W4A16")], DataFreePipeline)
14-
([GPTQModifier(scheme="FP8")], SequentialPipeline)
15-
([GPTQModifier(scheme="W4A16")], DataFreePipeline)
16-
([SmoothQuantModifier(), GPTQModifier(scheme="W4A16")], SequentialPipeline),
17-
([AWQModifier(scheme="W4A16")], SequentialPipeline)
18-
([AWQModifier(scheme="FP8")], SequentialPipeline)
19-
([SparseGPTQModifier()], SequentialPipeline)
20-
([WandaPruningModifier()], SequentialPipeline)
21-
([QuIPModifier()], DataFreePipeline)
22-
([SpinQuantModifier()], DataFreePipeline)
23-
([QuIPModifier(), QuantizationModifier(scheme="FP8")], SequentialPipeline)
24-
([QuIPModifier(), QuantizationModifier(scheme="W4A16")], DataFreePipeline)
25-
])
15+
@pytest.mark.parametrize(
16+
"modifiers,exp_pipeline",
17+
[
18+
([QuantizationModifier(scheme="FP8")], SequentialPipeline),
19+
([QuantizationModifier(scheme="W4A16")], DataFreePipeline),
20+
([GPTQModifier(scheme="FP8")], SequentialPipeline),
21+
([GPTQModifier(scheme="W4A16")], SequentialPipeline),
22+
([SmoothQuantModifier(), GPTQModifier(scheme="W4A16")], SequentialPipeline),
23+
([AWQModifier(scheme="W4A16")], SequentialPipeline),
24+
([AWQModifier(scheme="FP8")], SequentialPipeline),
25+
([SparseGPTModifier(sparsity=1.0)], SequentialPipeline),
26+
([WandaPruningModifier(sparsity=1.0)], SequentialPipeline),
27+
([QuIPModifier()], DataFreePipeline),
28+
([SpinQuantModifier()], DataFreePipeline),
29+
([QuIPModifier(), QuantizationModifier(scheme="FP8")], SequentialPipeline),
30+
([QuIPModifier(), QuantizationModifier(scheme="W4A16")], DataFreePipeline),
31+
],
32+
)
2633
def test_infer_pipeline(modifiers, exp_pipeline):
2734
pipeline = CalibrationPipeline.from_modifiers(modifiers)
28-
assert isinstance(pipeline, exp_pipeline)
35+
assert isinstance(pipeline, exp_pipeline)

0 commit comments

Comments
 (0)