 from torch.utils.data.dataloader import DataLoader
 
 from llmcompressor.modifiers import Modifier
-from llmcompressor.modifiers.awq import AWQModifier
-from llmcompressor.modifiers.obcq.sgpt_base import SparsityModifierBase
-from llmcompressor.modifiers.quantization import GPTQModifier, QuantizationMixin
-from llmcompressor.modifiers.smoothquant import SmoothQuantModifier
+from llmcompressor.modifiers.quantization import QuantizationModifier
 
 if TYPE_CHECKING:
     from llmcompressor.args.dataset_arguments import DatasetArguments
 
 __all__ = ["CalibrationPipeline"]
 
-CALIBRATION_MODIFIERS = (
-    SmoothQuantModifier,
-    AWQModifier,
-    GPTQModifier,
-    SparsityModifierBase,
-)
-
 
 class CalibrationPipeline(ABC, RegistryMixin):
     @staticmethod
@@ -48,7 +38,7 @@ def from_modifiers(
         :return: CalibrationPipeline instance to be called with data (if not datafree)
         """
         user = standardize_lookup_name(user) if user else None
-        inferred = standardize_lookup_name(cls._validate_infer_pipeline(modifiers))
+        inferred = standardize_lookup_name(cls._infer_pipeline(modifiers))
         independent = standardize_lookup_name("independent")
 
         if user == independent:
@@ -64,30 +54,11 @@ def from_modifiers(
         return cls.load_from_registry(pipeline)
 
     @staticmethod
-    def _validate_infer_pipeline(modifiers: List[Modifier]) -> str:
-        if any(isinstance(modifier, CALIBRATION_MODIFIERS) for modifier in modifiers):
-            return "sequential"
-
-        active_qmods = _get_active_quant_modifiers(modifiers)
-        if len(active_qmods) > 1:
-            raise ValueError(
-                f"Recipe contains more than one active quantization config "
-                f"({active_qmods}). These configs may be conflicting, Please modify "
-                "your recipe to use at most one quantization config"
-            )
-
-        if len(active_qmods) == 1:
-            quant_modifier = active_qmods[0]
-            config = quant_modifier.resolve_quantization_config()
-            if config.requires_calibration_data():
-                return "sequential"
-
-        return "datafree"
-
-
-def _get_active_quant_modifiers(modifiers: List[Modifier]) -> List[QuantizationMixin]:
-    return [
-        modifier
-        for modifier in modifiers
-        if isinstance(modifier, QuantizationMixin) and modifier.has_config()
-    ]
+    def _infer_pipeline(modifiers: List[Modifier]) -> str:
+        # only in the case of weight-only qmod quantization can we skip calibration
+        if len(modifiers) == 1 and isinstance(modifiers[0], QuantizationModifier):
+            config = modifiers[0].resolve_quantization_config()
+            if not config.requires_calibration_data():
+                return "datafree"
+
+        return "sequential"
0 commit comments