11import pytest
22
3- from llmcompressor .modifiers .quantization import QuantizationModifier , GPTQModifier
4- from llmcompressor .modifiers .smoothquant import SmoothQuantModifier
53from llmcompressor .modifiers .awq import AWQModifier
6- from llmcompressor .modifiers .pruning import SparseGPTQModifier , WandaPruningModifier
4+ from llmcompressor .modifiers .pruning import SparseGPTModifier , WandaPruningModifier
5+ from llmcompressor .modifiers .quantization import GPTQModifier , QuantizationModifier
6+ from llmcompressor .modifiers .smoothquant import SmoothQuantModifier
77from llmcompressor .modifiers .transform import QuIPModifier , SpinQuantModifier
8- from llmcompressor .pipelines import CalibrationPipeline , SequentialPipeline , DataFreePipeline
8+ from llmcompressor .pipelines import (
9+ CalibrationPipeline ,
10+ DataFreePipeline ,
11+ SequentialPipeline ,
12+ )
913
1014
11- @pytest .mark .parametrize ("modifiers" , [
12- ([QuantizationModifier (scheme = "FP8" )], SequentialPipeline )
13- ([QuantizationModifier (scheme = "W4A16" )], DataFreePipeline )
14- ([GPTQModifier (scheme = "FP8" )], SequentialPipeline )
15- ([GPTQModifier (scheme = "W4A16" )], DataFreePipeline )
16- ([SmoothQuantModifier (), GPTQModifier (scheme = "W4A16" )], SequentialPipeline ),
17- ([AWQModifier (scheme = "W4A16" )], SequentialPipeline )
18- ([AWQModifier (scheme = "FP8" )], SequentialPipeline )
19- ([SparseGPTQModifier ()], SequentialPipeline )
20- ([WandaPruningModifier ()], SequentialPipeline )
21- ([QuIPModifier ()], DataFreePipeline )
22- ([SpinQuantModifier ()], DataFreePipeline )
23- ([QuIPModifier (), QuantizationModifier (scheme = "FP8" )], SequentialPipeline )
24- ([QuIPModifier (), QuantizationModifier (scheme = "W4A16" )], DataFreePipeline )
25- ])
15+ @pytest .mark .parametrize (
16+ "modifiers,exp_pipeline" ,
17+ [
18+ ([QuantizationModifier (scheme = "FP8" )], SequentialPipeline ),
19+ ([QuantizationModifier (scheme = "W4A16" )], DataFreePipeline ),
20+ ([GPTQModifier (scheme = "FP8" )], SequentialPipeline ),
21+ ([GPTQModifier (scheme = "W4A16" )], SequentialPipeline ),
22+ ([SmoothQuantModifier (), GPTQModifier (scheme = "W4A16" )], SequentialPipeline ),
23+ ([AWQModifier (scheme = "W4A16" )], SequentialPipeline ),
24+ ([AWQModifier (scheme = "FP8" )], SequentialPipeline ),
25+ ([SparseGPTModifier (sparsity = 1.0 )], SequentialPipeline ),
26+ ([WandaPruningModifier (sparsity = 1.0 )], SequentialPipeline ),
27+ ([QuIPModifier ()], DataFreePipeline ),
28+ ([SpinQuantModifier ()], DataFreePipeline ),
29+ ([QuIPModifier (), QuantizationModifier (scheme = "FP8" )], SequentialPipeline ),
30+ ([QuIPModifier (), QuantizationModifier (scheme = "W4A16" )], DataFreePipeline ),
31+ ],
32+ )
2633def test_infer_pipeline (modifiers , exp_pipeline ):
2734 pipeline = CalibrationPipeline .from_modifiers (modifiers )
28- assert isinstance (pipeline , exp_pipeline )
35+ assert isinstance (pipeline , exp_pipeline )
0 commit comments