     TransformScheme,
     apply_transform_config,
 )
-from pydantic import Field, field_validator
+from pydantic import Field, ValidationInfo, field_validator
 from transformers import PreTrainedModel

 from llmcompressor.core import Event, EventType, State
@@ -27,6 +27,37 @@ class SpinquantRotation(str, Enum):


 class SpinQuantModifier(Modifier, use_enum_values=True):
+    """
+    Implements the transforms according to
+    [SpinQuant: LLM quantization with learned rotations](https://arxiv.org/abs/2405.16406)  # noqa: E501
+
+    Transforms (rotations) are extra layers added to a model which reduce the accuracy
+    loss induced by quantization. This is achieved through "rotating" weights and
+    activations into a space with a smaller dynamic range of values, thus decreasing
+    the range of scales required for quantization.
+
+    The SpinQuant authors describe four different rotations which can be applied to a
+    model. R1 and R2 are "offline" rotations, meaning that they can be fused into
+    existing weights and therefore do not induce runtime cost. R3 and R4 are "online"
+    rotations, meaning that they require additional computation at runtime.
+
+    :param rotations: A list containing the names of rotations to apply to the model.
+        Possible rotations include R1, R2, R3, and R4
+    :param transform_type: The type of transform to apply to the model.
+        `"hadamard"` has the least performance cost but only supports sizes which are
+        powers of two.
+        `"random-hadamard"` has more performance cost, but supports a much larger set
+        of sizes.
+        `"random-matrix"` has the greatest performance cost, but supports any size
+    :param randomize: if True, create distinct transforms for each application
+    :param learnable: if True, attach gradients to transform weights for training
+    :param mappings: Specifies layers within a model to target for transforms.
+        A mapping will be inferred if None is provided
+    :param norm_mappings: Specifies layers within a model to target for norm fusing.
+        A mapping will be inferred if None is provided
+    :param transform_config: Optional transform config which overrides `mappings`
+    """
+
     rotations: List[SpinquantRotation] = Field(default_factory=lambda: ["R1", "R2"])
     transform_type: Literal["hadamard", "random-hadamard", "random-matrix"] = Field(
         default="hadamard"
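
As a quick illustration of the parameters documented above, here is a minimal usage sketch. The import path and the lowercase rotation names are assumptions for illustration only (the `validate_rotations` validator added in this diff upper-cases them); they are not prescribed by this PR.

```python
# Hypothetical usage sketch; the import path is an assumption, adjust to your checkout.
from llmcompressor.modifiers.transform import SpinQuantModifier

# R1/R2 are "offline" rotations that get fused into existing weights.
# "hadamard" is the cheapest transform_type but requires power-of-two sizes.
spinquant = SpinQuantModifier(
    rotations=["r1", "r2"],       # upper-cased to R1/R2 by validate_rotations
    transform_type="hadamard",
)
```
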
@@ -43,23 +74,29 @@ class SpinQuantModifier(Modifier, use_enum_values=True):
     # also included in recipe serialization
     transform_config: Optional[TransformConfig] = Field(default=None)

+    @field_validator("randomize", "learnable", mode="before")
+    def validate_not_implemented(cls, value, info: ValidationInfo):
+        raise NotImplementedError(f"{info.field_name} is not supported right now")
+
     @field_validator("rotations", mode="before")
     def validate_rotations(cls, value):
         if isinstance(value, Iterable):
             return tuple(v.upper() for v in value)
         return value

     def on_initialize(self, state: State, **kwargs) -> bool:
-        # TODO: more validation
-        self.mappings = infer_mapping_from_model(state.model)
-        self.norm_mappings = infer_norm_mapping_from_model(state.model)
-
         if self.transform_config is not None:
             if self.mappings is not None:
-                raise ValueError()
+                raise ValueError(
+                    "Please provide either `transform_config` or `mappings` "
+                    "but not both"
+                )

             return True

+        self.mappings = infer_mapping_from_model(state.model)
+        self.norm_mappings = infer_norm_mapping_from_model(state.model)
+
         config_groups = {}
         if SpinquantRotation.R1 in self.rotations:
             config_groups["R1"] = self._create_r1_scheme()
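
To make the "offline rotation" claim in the new docstring concrete, here is a small illustrative sketch (my own, not part of the diff): folding an orthogonal matrix R into a weight leaves the layer output unchanged, since (WR)(Rᵀx) = Wx, while the rotated activations typically spread outliers across channels and shrink the dynamic range a quantizer must cover.

```python
# Illustrative sketch only; a random orthogonal matrix stands in for a
# learned or Hadamard rotation R.
import numpy as np

rng = np.random.default_rng(0)
d = 8

R, _ = np.linalg.qr(rng.normal(size=(d, d)))  # random orthogonal matrix

W = rng.normal(size=(d, d))   # original linear weight
x = rng.normal(size=d)
x[0] = 50.0                   # an outlier channel that inflates quantization scales

W_fused = W @ R               # rotation folded into the weight "offline"
x_rot = R.T @ x               # matching rotation applied to the activations

# The layer output is unchanged: (W R)(R^T x) == W x
assert np.allclose(W_fused @ x_rot, W @ x)

# The rotated activations typically have a smaller peak magnitude.
print("max |x|     =", np.abs(x).max())
print("max |R^T x| =", np.abs(x_rot).max())
```
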