Skip to content

Commit b336fa2

Browse files
committed
deprecate sequential_targets on modifiers
Signed-off-by: Kyle Sayers <[email protected]>
1 parent a389d14 commit b336fa2

File tree

4 files changed: +60 / -28 lines changed

src/llmcompressor/args/dataset_arguments.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,3 +186,13 @@ class DatasetArguments(CustomDatasetArguments):
186186
"{module}.{method_name} or {function_name}"
187187
},
188188
)
189+
sequential_targets: Optional[List[str]] = field(
190+
default=None,
191+
metadata={
192+
"help": "List of layer targets for the sequential pipeline. "
193+
"This is typically a single DecoderLayer. "
194+
"Not specifying this argument will cause the sequential pipeline to "
195+
"default to using the `no_split_params` specified by the HF model "
196+
"definition"
197+
},
198+
)

src/llmcompressor/pipelines/registry.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,12 @@
1717

1818
__all__ = ["CalibrationPipeline"]
1919

20-
SEQUENTIAL_MODIFIERS = (AWQModifier, GPTQModifier, SparsityModifierBase)
21-
CALIBRATION_MODIFIERS = (SmoothQuantModifier, *SEQUENTIAL_MODIFIERS)
20+
CALIBRATION_MODIFIERS = (
21+
SmoothQuantModifier,
22+
AWQModifier,
23+
GPTQModifier,
24+
SparsityModifierBase,
25+
)
2226

2327

2428
class CalibrationPipeline(ABC, RegistryMixin):

src/llmcompressor/pipelines/sequential/helpers.py

Lines changed: 42 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import inspect
33
from collections import deque
44
from dataclasses import dataclass
5-
from typing import Any, Dict, List, Optional, Set
5+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set
66

77
import torch
88
from compressed_tensors import has_offloaded_params
@@ -23,7 +23,10 @@
2323

2424
from .ast_helpers import autowrap_forwards
2525

26-
__all__ = ["trace_subgraphs", "Subgraph", "get_targets_from_modifiers"]
26+
if TYPE_CHECKING:
27+
from llmcompressor.args.dataset_arguments import DatasetArguments
28+
29+
__all__ = ["trace_subgraphs", "Subgraph", "get_sequential_targets"]
2730

2831

2932
@dataclass
@@ -416,44 +419,59 @@ def match_modules(model: Module, target_names: List[str]) -> Set[Module]:
416419
)
417420

418421

419-
def get_targets_from_modifiers(
420-
modifiers: List[Modifier], model: PreTrainedModel
422+
def get_sequential_targets(
423+
modifiers: List[Modifier], model: PreTrainedModel, args: "DatasetArguments"
421424
) -> List[str]:
422425
"""
423-
Infer sequential targets from modifiers list
426+
Infer sequential targets from modifiers list and dataset args
424427
425428
:param model: model being calibrated
426429
:param modifiers: list of modifiers being applied during calibration
430+
:param dataset_args: dataset arguments passed by user
427431
:return: list of sequential targets
428432
"""
429-
# avoid circular import
430-
from llmcompressor.pipelines.registry import SEQUENTIAL_MODIFIERS
431-
432-
sequential_modifiers = [
433-
modifier for modifier in modifiers if isinstance(modifier, SEQUENTIAL_MODIFIERS)
433+
modifier_targets = [
434+
(modifier, modifier.sequential_targets)
435+
for modifier in modifiers
436+
if getattr(modifier, "sequential_targets", None) is not None
434437
]
435438

436-
if len(sequential_modifiers) >= 2:
437-
types = [type(modifier) for modifier in sequential_modifiers]
439+
# deprecation warning
440+
if len(modifier_targets) > 1:
438441
logger.warning(
442+
"Passing sequential targets through modifiers is deprecated, "
443+
"please use `oneshot(sequential_targets=...)`"
444+
)
445+
446+
# cannot infer from multiple modifiers
447+
if len(modifier_targets) >= 2:
448+
types = [type(modifier) for modifier, _ in modifier_targets]
449+
raise ValueError(
439450
"Cannot infer sequential targets from multiple sequential modifiers "
440-
f"({types}). Defaulting to {types[0]}"
451+
f"({types})"
441452
)
442-
elif len(sequential_modifiers) <= 0:
443-
types = [type(modifier) for modifier in modifiers]
444-
raise ValueError(f"Cannot infer sequential targets from list of {types}")
445453

446-
modifier = sequential_modifiers[0]
454+
# resolve single modifier
455+
if len(modifier_targets) == 1:
456+
if args.sequential_targets is not None:
457+
raise ValueError(
458+
f"Got sequential targets from both {type(modifier_targets[0][0])} "
459+
"and dataset arguments `sequential_targets`"
460+
)
461+
462+
sequential_targets = modifier_targets[0][1]
447463

448-
# infer sequential targets
449-
if modifier.sequential_targets is None:
450-
sequential_targets = get_no_split_params(model)
451-
elif isinstance(modifier.sequential_targets, str):
452-
sequential_targets = [modifier.sequential_targets]
464+
# if no modifiers, use data args
453465
else:
454-
sequential_targets = modifier.sequential_targets
466+
sequential_targets = args.sequential_targets # may be `None`
455467

456-
return sequential_targets
468+
# validate and infer
469+
if sequential_targets is None:
470+
return get_no_split_params(model)
471+
elif isinstance(sequential_targets, str):
472+
return [sequential_targets]
473+
else:
474+
return sequential_targets
457475

458476

459477
def add_line_numbers(text: str) -> str:

src/llmcompressor/pipelines/sequential/pipeline.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from llmcompressor.pipelines.cache import IntermediatesCache
1212
from llmcompressor.pipelines.registry import CalibrationPipeline
1313
from llmcompressor.pipelines.sequential.helpers import (
14-
get_targets_from_modifiers,
14+
get_sequential_targets,
1515
trace_subgraphs,
1616
)
1717
from llmcompressor.utils.helpers import DisableQuantization, calibration_forward_context
@@ -64,7 +64,7 @@ def __call__(
6464

6565
# prepare to trace subgraphs
6666
modifiers = session.get_modifiers()
67-
sequential_targets = get_targets_from_modifiers(modifiers, model)
67+
sequential_targets = get_sequential_targets(modifiers, model, dataset_args)
6868
ignore = dataset_args.tracing_ignore
6969

7070
# trace subgraphs

Comments (0)