Skip to content

Commit 53ea307

Browse files
committed
standardize, make modifier serializable
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 0cf0188 commit 53ea307

File tree

5 files changed

+13
-105
lines changed

5 files changed

+13
-105
lines changed

examples/transform/llama3_example.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -58,7 +58,7 @@ def tokenize(sample):
5858
# * apply spinquant transforms to model in order to make quantization easier
5959
# * quantize the weights to 4 bit with GPTQ with a group size 128
6060
recipe = [
61-
SpinQuantModifier(rotations=["R1", "R2"], transform_type="random-hadamard"),
61+
SpinQuantModifier(rotations=["R1", "R2"], transform_type="hadamard"),
6262
QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]),
6363
]
6464

src/llmcompressor/modifiers/transform/quip/base.py

Whitespace-only changes.

src/llmcompressor/modifiers/transform/quip/template.py

Lines changed: 0 additions & 98 deletions
This file was deleted.

src/llmcompressor/modifiers/transform/spinquant/base.py

Lines changed: 7 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -19,15 +19,15 @@
1919
from .norm_mappings import NORM_MAPPING_REGISTRY, NormMapping
2020

2121

22-
class SpinquantRotation(Enum):
22+
class SpinquantRotation(str, Enum):
2323
R1 = "R1"
2424
R2 = "R2"
2525
R3 = "R3"
2626
R4 = "R4"
2727

2828

29-
class SpinQuantModifier(Modifier):
30-
rotations: Iterable[SpinquantRotation] = ("R1", "R2")
29+
class SpinQuantModifier(Modifier, use_enum_values=True):
30+
rotations: List[SpinquantRotation] = Field(default_factory=lambda: ["R1", "R2"])
3131
transform_type: Literal["hadamard", "random-hadamard", "random-matrix"] = Field(
3232
default="hadamard"
3333
)
@@ -38,11 +38,12 @@ class SpinQuantModifier(Modifier):
3838
# override spinquant mappings with transform_config without overriding norms
3939
# we can combine these mappings, but it requires some more validation logic
4040
# maybe there's a reason to keep if other modifiers want norm fusing, idk
41-
mappings: Optional[SpinQuantMappings] = None
42-
norm_mappings: Optional[List[NormMapping]] = None
41+
mappings: Optional[SpinQuantMappings] = Field(default=None, exclude=True)
42+
norm_mappings: Optional[List[NormMapping]] = Field(default=None, exclude=True)
4343

4444
# optional override for more fine-grained control
45-
transform_config: Optional[TransformConfig] = None
45+
# also included in recipe serialization
46+
transform_config: Optional[TransformConfig] = Field(default=None)
4647

4748
@field_validator("rotations", mode="before")
4849
def validate_rotations(cls, value):

src/llmcompressor/pipelines/registry.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@
88

99
from llmcompressor.modifiers import Modifier
1010
from llmcompressor.modifiers.quantization import QuantizationModifier
11+
from llmcompressor.modifiers.transform import SpinQuantModifier
1112

1213
if TYPE_CHECKING:
1314
from llmcompressor.args.dataset_arguments import DatasetArguments
@@ -60,5 +61,9 @@ def _infer_pipeline(modifiers: List[Modifier]) -> str:
6061
config = modifiers[0].resolve_quantization_config()
6162
if not config.requires_calibration_data():
6263
return "datafree"
64+
65+
# TODO: Remove hardcode
66+
if len(modifiers) == 1 and isinstance(modifiers[0], SpinQuantModifier):
67+
return "datafree"
6368

6469
return "sequential"

0 commit comments

Comments (0)