Commit 501056e

simplify pipeline inference logic, add comment
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 7d7b00d commit 501056e

2 files changed: 11 additions & 40 deletions
src/llmcompressor/pipelines/basic/pipeline.py

Lines changed: 1 addition & 1 deletion
@@ -38,7 +38,7 @@ def __call__(
         :param dataloader: loads data for calibration
         :param dataset_args: dataset arguments relevant to pipelines
         """
-        dispatch_for_generation(model)
+        dispatch_for_generation(model)  # basic dispatch is identical to generation
         model_device = get_execution_device(model)
 
         LifecycleCallbacks.calibration_epoch_start()
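
For context, a minimal sketch of how this dispatch step sits in the basic pipeline's calibration flow. The names `dispatch_for_generation`, `get_execution_device`, and `LifecycleCallbacks.calibration_epoch_start` appear in the diff above; the import paths, the `calibration_epoch_end` counterpart, and the loop itself are assumptions, not verified against the repository.

# Hedged sketch of the basic pipeline's flow; import paths are assumptions.
import torch
from compressed_tensors.utils import get_execution_device  # assumed path
from llmcompressor.core import LifecycleCallbacks  # assumed path
from llmcompressor.utils.dev import dispatch_for_generation  # assumed path

def run_basic_calibration(model, dataloader):
    # The basic pipeline reuses the generation dispatch: layers are placed
    # across devices exactly as they would be for inference, hence the new
    # comment "basic dispatch is identical to generation".
    dispatch_for_generation(model)
    model_device = get_execution_device(model)

    LifecycleCallbacks.calibration_epoch_start()
    with torch.no_grad():
        for batch in dataloader:
            # move each calibration batch onto the model's execution device
            batch = {key: value.to(model_device) for key, value in batch.items()}
            model(**batch)  # forward pass triggers calibration hooks
    LifecycleCallbacks.calibration_epoch_end()  # assumed counterpart callback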

src/llmcompressor/pipelines/registry.py

Lines changed: 10 additions & 39 deletions
@@ -7,23 +7,13 @@
 from torch.utils.data.dataloader import DataLoader
 
 from llmcompressor.modifiers import Modifier
-from llmcompressor.modifiers.awq import AWQModifier
-from llmcompressor.modifiers.obcq.sgpt_base import SparsityModifierBase
-from llmcompressor.modifiers.quantization import GPTQModifier, QuantizationMixin
-from llmcompressor.modifiers.smoothquant import SmoothQuantModifier
+from llmcompressor.modifiers.quantization import QuantizationModifier
 
 if TYPE_CHECKING:
     from llmcompressor.args.dataset_arguments import DatasetArguments
 
 __all__ = ["CalibrationPipeline"]
 
-CALIBRATION_MODIFIERS = (
-    SmoothQuantModifier,
-    AWQModifier,
-    GPTQModifier,
-    SparsityModifierBase,
-)
-
 
 class CalibrationPipeline(ABC, RegistryMixin):
     @staticmethod
@@ -48,7 +38,7 @@ def from_modifiers(
         :return: CalibrationPipeline instance to be called with data (if not datafree)
         """
         user = standardize_lookup_name(user) if user else None
-        inferred = standardize_lookup_name(cls._validate_infer_pipeline(modifiers))
+        inferred = standardize_lookup_name(cls._infer_pipeline(modifiers))
         independent = standardize_lookup_name("independent")
 
         if user == independent:
@@ -64,30 +54,11 @@ def from_modifiers(
         return cls.load_from_registry(pipeline)
 
     @staticmethod
-    def _validate_infer_pipeline(modifiers: List[Modifier]) -> str:
-        if any(isinstance(modifier, CALIBRATION_MODIFIERS) for modifier in modifiers):
-            return "sequential"
-
-        active_qmods = _get_active_quant_modifiers(modifiers)
-        if len(active_qmods) > 1:
-            raise ValueError(
-                f"Recipe contains more than one active quantization config "
-                f"({active_qmods}). These configs may be conflicting, Please modify "
-                "your recipe to use at most one quantization config"
-            )
-
-        if len(active_qmods) == 1:
-            quant_modifier = active_qmods[0]
-            config = quant_modifier.resolve_quantization_config()
-            if config.requires_calibration_data():
-                return "sequential"
-
-        return "datafree"
-
-
-def _get_active_quant_modifiers(modifiers: List[Modifier]) -> List[QuantizationMixin]:
-    return [
-        modifier
-        for modifier in modifiers
-        if isinstance(modifier, QuantizationMixin) and modifier.has_config()
-    ]
+    def _infer_pipeline(modifiers: List[Modifier]) -> str:
+        # only in the case of weight-only qmod quantization can we skip calibration
+        if len(modifiers) == 1 and isinstance(modifiers[0], QuantizationModifier):
+            config = modifiers[0].resolve_quantization_config()
+            if not config.requires_calibration_data():
+                return "datafree"
+
+        return "sequential"
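
To make the simplified rule concrete, here is a self-contained toy reproduction of the new `_infer_pipeline` decision: a recipe is data-free only when it consists of exactly one `QuantizationModifier` whose resolved config needs no calibration data; everything else falls back to the sequential pipeline. The `Fake*` classes below are hypothetical stand-ins, not llmcompressor APIs.

from dataclasses import dataclass
from typing import Any, List

@dataclass
class FakeConfig:
    needs_data: bool

    def requires_calibration_data(self) -> bool:
        return self.needs_data

@dataclass
class FakeQuantizationModifier:
    config: FakeConfig

    def resolve_quantization_config(self) -> FakeConfig:
        return self.config

def infer_pipeline(modifiers: List[Any]) -> str:
    # mirrors the committed logic: only a lone quantization modifier whose
    # config needs no calibration data can skip calibration entirely
    if len(modifiers) == 1 and isinstance(modifiers[0], FakeQuantizationModifier):
        if not modifiers[0].resolve_quantization_config().requires_calibration_data():
            return "datafree"
    return "sequential"

assert infer_pipeline([FakeQuantizationModifier(FakeConfig(needs_data=False))]) == "datafree"
assert infer_pipeline([FakeQuantizationModifier(FakeConfig(needs_data=True))]) == "sequential"
assert infer_pipeline([object(), object()]) == "sequential"  # multi-modifier recipes always calibrate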
