
Commit 7a52b71

docstring
Signed-off-by: Kyle Sayers <[email protected]>
1 parent b44ac81 commit 7a52b71

File tree

1 file changed: +43 −6 lines changed
  • src/llmcompressor/modifiers/transform/spinquant


src/llmcompressor/modifiers/transform/spinquant/base.py

Lines changed: 43 additions & 6 deletions
@@ -8,7 +8,7 @@
     TransformScheme,
     apply_transform_config,
 )
-from pydantic import Field, field_validator
+from pydantic import Field, ValidationInfo, field_validator
 from transformers import PreTrainedModel
 
 from llmcompressor.core import Event, EventType, State
@@ -27,6 +27,37 @@ class SpinquantRotation(str, Enum):
 
 
 class SpinQuantModifier(Modifier, use_enum_values=True):
+    """
+    Implements the transforms according to
+    [SpinQuant: LLM quantization with learned rotations](https://arxiv.org/abs/2405.16406) # noqa: E501
+
+    Transforms (rotations) are extra layers added to a model which reduce the accuracy
+    loss induced by quantization. This is achieved through "rotating" weights and
+    activations into a space with a smaller dynamic range of values, thus decreasing
+    the range of scales required for quantization.
+
+    The SpinQuant authors describe four different rotations which can be applied to a
+    model. R1 and R2 are "offline" rotations, meaning that they can be fused into
+    existing weights and therefore do not induce runtime cost. R3 and R4 are "online"
+    rotations, meaning that they require additional computation at runtime.
+
+    :param rotations: A list containing the names of rotations to apply to the model.
+        Possible rotations include R1, R2, R3, and R4
+    :param transform_type: The type of transform to apply to the model.
+        `"hadamard"` has the least performance cost but only supports sizes which are
+        powers of two.
+        `"random-hadamard"` has more performance cost, but supports a much larger set of
+        sizes.
+        `"random-matrix"` has the greatest performance cost, but supports any size
+    :param randomize: if True, create distinct transforms for each application
+    :param learnable: if True, attach gradients to transform weights for training
+    :param mappings: Specifies layers within a model to target for transforms.
+        A mapping will be inferred if None is provided
+    :param norm_mappings: Specifies layers within a model to target for norm fusing.
+        A mapping will be inferred if None is provided
+    :param transform_config: Optional transform config which overrides `mappings`
+    """
+
     rotations: List[SpinquantRotation] = Field(default_factory=lambda: ["R1", "R2"])
     transform_type: Literal["hadamard", "random-hadamard", "random-matrix"] = Field(
         default="hadamard"
@@ -43,23 +74,29 @@ class SpinQuantModifier(Modifier, use_enum_values=True):
     # also included in recipe serialization
     transform_config: Optional[TransformConfig] = Field(default=None)
 
+    @field_validator("randomize", "learnable", mode="before")
+    def validate_not_implemented(cls, value, info: ValidationInfo):
+        raise NotImplementedError(f"{info.field_name} is not supported right now")
+
     @field_validator("rotations", mode="before")
     def validate_rotations(cls, value):
         if isinstance(value, Iterable):
             return tuple(v.upper() for v in value)
         return value
 
     def on_initialize(self, state: State, **kwargs) -> bool:
-        # TODO: more validation
-        self.mappings = infer_mapping_from_model(state.model)
-        self.norm_mappings = infer_norm_mapping_from_model(state.model)
-
         if self.transform_config is not None:
             if self.mappings is not None:
-                raise ValueError()
+                raise ValueError(
+                    "Please provide either `transform_config` or `mappings` "
+                    "but not both"
+                )
 
             return True
 
+        self.mappings = infer_mapping_from_model(state.model)
+        self.norm_mappings = infer_norm_mapping_from_model(state.model)
+
         config_groups = {}
         if SpinquantRotation.R1 in self.rotations:
             config_groups["R1"] = self._create_r1_scheme()
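
The docstring added above explains the rotation idea only in prose. As a rough, self-contained illustration (not part of this commit, with made-up numbers), the sketch below shows how multiplying by an orthogonal Hadamard matrix spreads an activation outlier across channels, shrinking the dynamic range, while the layer output is unchanged because the rotation cancels against its transpose fused into the weights:

import numpy as np

# 4x4 Hadamard matrix, scaled so that H is orthogonal (H @ H.T == I)
H = np.array([
    [1,  1,  1,  1],
    [1, -1,  1, -1],
    [1,  1, -1, -1],
    [1, -1, -1,  1],
]) / 2.0

x = np.array([8.0, 0.1, -0.2, 0.1])   # activation vector with one large outlier
W = np.random.randn(4, 4)             # toy weight matrix

y_ref = W @ x                         # original computation
y_rot = (W @ H) @ (H.T @ x)           # rotated weights times rotated activations

print(np.abs(x).max(), np.abs(H.T @ x).max())   # 8.0 -> 4.1: outlier is spread out
print(np.allclose(y_ref, y_rot))                # True: output is unchanged

In the actual method the activation-side rotation of the offline rotations (R1, R2) is also absorbed into neighbouring layers; the toy above only demonstrates the dynamic-range effect described in the docstring.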

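For reference, a hypothetical usage sketch of the fields documented and validated in this commit; the import path is assumed from the file location above and nothing here is taken verbatim from the repository's examples.

# Hypothetical sketch; assumes the class is importable from the module shown in this diff.
from llmcompressor.modifiers.transform.spinquant.base import SpinQuantModifier

# R1/R2 are the "offline" rotations and the documented defaults; "hadamard" is the
# cheapest transform_type and requires power-of-two sizes.
modifier = SpinQuantModifier(rotations=["r1", "r2"], transform_type="hadamard")
# The `rotations` validator upper-cases entries, so this is equivalent to ["R1", "R2"].

# Per the new `validate_not_implemented` validator, explicitly passing `randomize` or
# `learnable` (any value) raises NotImplementedError for now:
#     SpinQuantModifier(randomize=True)   # NotImplementedError: randomize is not supported right now
# And per the new check in `on_initialize`, providing both `transform_config` and
# `mappings` raises a ValueError at initialization time.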
0 commit comments
