from enum import Enum
from typing import List, Literal, Optional

from compressed_tensors.transform import (
    TransformArgs,
    TransformConfig,
    TransformScheme,
    apply_transform_config,
)
from pydantic import BaseModel, Field, field_validator
from transformers import PreTrainedModel

from llmcompressor.core import Event, EventType, State
from llmcompressor.modeling import fuse_norm_linears
from llmcompressor.modifiers import Modifier


class SpinQuantMappings(BaseModel):
    """Regex targets for the modules that the SpinQuant rotations attach to."""

    embedding: str

    attn_q: str
    attn_k: str
    attn_v: str
    attn_o: str
    attn_head_dim: Optional[int] = Field(default=None)

    mlp_in: List[str]  # up_proj, gate_proj
    mlp_out: List[str]  # down_proj

    lm_head: str

    @field_validator("mlp_in", "mlp_out", mode="before")
    def cast_to_list(cls, value):
        # accept a single pattern in place of a list
        if isinstance(value, str):
            return [value]

        return value


class NormMapping(BaseModel):
    """A norm layer and the linear layers it is fused into."""

    norm: str
    linears: List[str]

    @field_validator("linears", mode="before")
    def cast_to_list(cls, value):
        # accept a single pattern in place of a list
        if isinstance(value, str):
            return [value]

        return value


llama_spinquant = SpinQuantMappings(
    embedding="re:.*embed_tokens$",

    attn_q="re:.*q_proj$",
    attn_k="re:.*k_proj$",
    attn_v="re:.*v_proj$",
    attn_o="re:.*o_proj$",

    mlp_in=["re:.*up_proj$", "re:.*gate_proj$"],
    mlp_out="re:.*down_proj$",

    lm_head="lm_head",
)

llama_norm_mappings = [
    NormMapping(
        norm="re:.*input_layernorm$",
        linears=["re:.*q_proj$", "re:.*k_proj$", "re:.*v_proj$"],
    ),
    NormMapping(
        norm="re:.*post_attention_layernorm$",
        linears=["re:.*up_proj$", "re:.*gate_proj$"],
    ),
]
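
# NOTE: these defaults assume Llama-style module names (q_proj, up_proj,
# input_layernorm, ...); other architectures need their own SpinQuantMappings
# and NormMapping definitions.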


class SpinquantRotation(Enum):
    R1 = "R1"
    R2 = "R2"
    R3 = "R3"
    R4 = "R4"
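
# In the SpinQuant formulation, R1 and R2 can be fused directly into the weights
# (handled by the weight_input/weight_output schemes below), while R3 and R4 are
# online activation rotations; the R3/R4 schemes are not implemented here yet.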
78
+
79
+ class SpinQuantModifier (Modifier ):
80
+ rotations : List [SpinquantRotation ] = Field (default_factory = lambda : ["R1" , "R2" ])
81
+
82
+ transform_type : Literal ["hadamard" , "random-hadamard" , "random-matrix" ] = Field (default = "hadamard" )
83
+ randomize : bool = Field (default = False )
84
+ learnable : bool = Field (default = False )
85
+
86
+ mappings : Optional [SpinQuantMappings ] = None
87
+ norm_mappings : Optional [List [NormMapping ]] = None
88
+
89
+ transform_config : Optional [TransformConfig ] = None # optional override for more fine-grained control
90
+

    def on_initialize(self, state: State, **kwargs) -> bool:
        if self.transform_config is not None:
            # a fully specified transform config overrides the mapping-based path
            if self.mappings is not None:
                raise ValueError(
                    "Provide either `transform_config` or `mappings`, not both"
                )

            return True

        # HARDCODE: llama-only mappings for now
        self.mappings = llama_spinquant
        self.norm_mappings = llama_norm_mappings

        config_groups = {}
        for rotation in self.rotations:
            if rotation == SpinquantRotation.R1:
                config_groups["R1"] = self._create_r1_scheme()

            if rotation == SpinquantRotation.R2:
                config_groups["R2"] = self._create_r2_scheme(state.model)

            if rotation == SpinquantRotation.R3:
                config_groups["R3"] = self._create_r3_scheme()

            if rotation == SpinquantRotation.R4:
                config_groups["R4"] = self._create_r4_scheme()

        self.transform_config = TransformConfig(config_groups=config_groups)

        return True

    def on_start(self, state: State, event: Event, **kwargs):
        self.started_ = True

        # fuse norm scales into the adjacent linears so the rotations commute with the norms
        for layer in state.model.model.layers:
            fuse_norm_linears(
                layer.input_layernorm,
                (layer.self_attn.q_proj, layer.self_attn.k_proj, layer.self_attn.v_proj),
            )
            fuse_norm_linears(
                layer.post_attention_layernorm,
                (layer.mlp.gate_proj, layer.mlp.up_proj),
            )

        # needs to happen after the model has been hooked to execute on the GPU,
        # otherwise we're applying weight transforms on CPU
        apply_transform_config(state.model, self.transform_config)

    def on_event(self, state: State, event: Event, **kwargs):
        if event.type_ == EventType.CALIBRATION_EPOCH_START:
            if not self.started_:
                self.on_start(state, None)

        elif event.type_ == EventType.SEQUENTIAL_EPOCH_END:
            pass

        elif event.type_ == EventType.CALIBRATION_EPOCH_END:
            if not self.ended_:
                self.on_end(state, None)

    def on_end(self, state: State, event: Event, **kwargs):
        self.ended_ = True

    def on_finalize(self, state: State, **kwargs) -> bool:
        if not self.ended_:
            self.on_end(state, None)

        return True

    def _create_r1_scheme(self) -> TransformScheme:
        return TransformScheme(
            type=self.transform_type,
            randomize=self.randomize,
            requires_grad=self.learnable,
            apply=[
                TransformArgs(
                    targets=[
                        self.mappings.embedding,
                        self.mappings.attn_o,
                        *self.mappings.mlp_out,
                    ],
                    location="weight_output",
                ),
                TransformArgs(
                    targets=[
                        self.mappings.attn_q,
                        self.mappings.attn_k,
                        self.mappings.attn_v,
                        *self.mappings.mlp_in,
                        self.mappings.lm_head,
                    ],
                    location="weight_input",
                    inverse=True,
                ),
            ],
        )

    def _create_r2_scheme(self, model: PreTrainedModel) -> TransformScheme:
        config = model.config

        if hasattr(config, "head_dim"):
            head_dim = config.head_dim
        elif hasattr(config, "hidden_size") and hasattr(config, "num_attention_heads"):
            head_dim = config.hidden_size // config.num_attention_heads
        else:
            raise NotImplementedError(
                "Could not determine the attention head dim from the model config"
            )

        return TransformScheme(
            type=self.transform_type,
            randomize=self.randomize,
            requires_grad=self.learnable,
            head_dim=head_dim,
            apply=[
                TransformArgs(targets=[self.mappings.attn_v], location="weight_output"),
                TransformArgs(
                    targets=[self.mappings.attn_o],
                    location="weight_input",
                    inverse=True,
                ),
            ],
        )

    def _create_r3_scheme(self) -> TransformScheme:
        raise NotImplementedError("R3 rotations are not yet supported")

    def _create_r4_scheme(self) -> TransformScheme:
        raise NotImplementedError("R4 rotations are not yet supported")
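

# Usage sketch (assumption, not part of this module): like other llmcompressor
# modifiers, this one is expected to be driven through the `oneshot` entrypoint.
# The model id below is only an example.
#
#   from llmcompressor import oneshot
#
#   recipe = SpinQuantModifier(rotations=["R1", "R2"], transform_type="hadamard")
#   oneshot(model="meta-llama/Llama-3.2-1B-Instruct", recipe=recipe)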