@@ -1,14 +1,18 @@
-from typing import Optional, List, Literal, Iterable
+from enum import Enum
+from typing import Iterable, List, Literal, Optional
 
-from compressed_tensors.transform import TransformConfig, TransformScheme, TransformArgs, apply_transform_config
-from pydantic import BaseModel, field_validator, Field
+from compressed_tensors.transform import (
+    TransformArgs,
+    TransformConfig,
+    TransformScheme,
+    apply_transform_config,
+)
+from pydantic import BaseModel, Field, field_validator
+from transformers import PreTrainedModel
 
 from llmcompressor.core import Event, EventType, State
 from llmcompressor.modeling import fuse_norm_linears
 from llmcompressor.modifiers import Modifier
-from enum import Enum
-
-from transformers import PreTrainedModel
 
 
 class SpinQuantMappings(BaseModel):
@@ -29,9 +33,10 @@ class SpinQuantMappings(BaseModel):
     def cast_to_list(cls, value):
         if isinstance(value, str):
             return [value]
-
+
         return value
-
+
+
 class NormMapping(BaseModel):
     norm: str
     linears: List[str]
@@ -40,22 +45,18 @@ class NormMapping(BaseModel):
     def cast_to_list(cls, value):
         if isinstance(value, str):
             return [value]
-
-        return value
 
+        return value
 
 
 llama_spinquant = SpinQuantMappings(
     embedding="re:.*embed_tokens$",
-
     attn_q="re:.*q_proj$",
     attn_k="re:.*k_proj$",
     attn_v="re:.*v_proj$",
     attn_o="re:.*o_proj$",
-
     mlp_in=["re:.*up_proj$", "re:.*gate_proj$"],
     mlp_out="re:.*down_proj$",
-
     lm_head="lm_head",
 )
 
@@ -67,25 +68,31 @@ def cast_to_list(cls, value):
     NormMapping(
         norm="re:.*post_attention_layernorm$",
         linears=["re:.*up_proj$", "re:.*gate_proj$"],
-    )
+    ),
 ]
 
+
 class SpinquantRotation(Enum):
     R1 = "R1"
     R2 = "R2"
     R3 = "R3"
     R4 = "R4"
 
+
 class SpinQuantModifier(Modifier):
     rotations: Iterable[SpinquantRotation] = ("R1", "R2")
-    transform_type: Literal["hadamard", "random-hadamard", "random-matrix"] = Field(default="hadamard")
+    transform_type: Literal["hadamard", "random-hadamard", "random-matrix"] = Field(
+        default="hadamard"
+    )
     randomize: bool = Field(default=False)
     learnable: bool = Field(default=False)
 
     mappings: Optional[SpinQuantMappings] = None
     norm_mappings: Optional[List[NormMapping]] = None
-
-    transform_config: Optional[TransformConfig] = None  # optional override for more fine-grained control
+
+    transform_config: Optional[TransformConfig] = (
+        None  # optional override for more fine-grained control
+    )
 
     @field_validator("rotations", mode="before")
     def validate_rotations(cls, value):
@@ -101,7 +108,7 @@ def on_initialize(self, state: State, **kwargs) -> bool:
         if self.transform_config is not None:
             if self.mappings is not None:
                 raise ValueError()
-
+
             return True
 
         config_groups = {}
@@ -129,6 +136,7 @@ def on_start(self, state: State, event: Event, **kwargs):
         # Embedding fusion
         # theoretically, doesn't do anything. Doesn't seem to help model sanity either
         from compressed_tensors import update_offload_parameter
+
         for W in [state.model.model.embed_tokens]:
             W_ = W.weight.data.double()
             W.weight.data = (W_ - W_.mean(dim=-1, keepdim=True)).to(W.weight.data.dtype)
@@ -138,16 +146,24 @@ def on_start(self, state: State, event: Event, **kwargs):
         # TODO: use norm mappings
         # layer norm fusion
         for layer in state.model.model.layers:
-            fuse_norm_linears(layer.input_layernorm, (layer.self_attn.q_proj, layer.self_attn.k_proj, layer.self_attn.v_proj))
-            fuse_norm_linears(layer.post_attention_layernorm, (layer.mlp.gate_proj, layer.mlp.up_proj))
+            fuse_norm_linears(
+                layer.input_layernorm,
+                (
+                    layer.self_attn.q_proj,
+                    layer.self_attn.k_proj,
+                    layer.self_attn.v_proj,
+                ),
+            )
+            fuse_norm_linears(
+                layer.post_attention_layernorm, (layer.mlp.gate_proj, layer.mlp.up_proj)
+            )
+
+        fuse_norm_linears(state.model.model.norm, (state.model.lm_head,))
 
         # needs to happen after the model has been hooked to execute on the GPU
         # otherwise we're applying weight transforms on CPU
         apply_transform_config(state.model, self.transform_config)
 
-
-
-
     def on_event(self, state: State, event: Event, **kwargs):
         if event.type_ == EventType.CALIBRATION_EPOCH_START:
             if not self.started_:
@@ -169,7 +185,6 @@ def on_finalize(self, state: State, **kwargs) -> bool:
 
         return True
 
-
     def _create_r1_scheme(self) -> TransformScheme:
         return TransformScheme(
             type=self.transform_type,
@@ -190,14 +205,14 @@ def _create_r1_scheme(self) -> TransformScheme:
                         self.mappings.attn_k,
                         self.mappings.attn_v,
                         *self.mappings.mlp_in,
-                        self.mappings.lm_head
+                        self.mappings.lm_head,
                     ],
                     location="weight_input",
                     inverse=True,
                 ),
-            ]
+            ],
         )
-
+
     def _create_r2_scheme(self, model: PreTrainedModel) -> TransformScheme:
         config = model.config
 
@@ -207,7 +222,7 @@ def _create_r2_scheme(self, model: PreTrainedModel) -> TransformScheme:
             head_dim = config.hidden_size // config.num_attention_heads
         else:
             raise NotImplementedError()
-
+
         return TransformScheme(
             type=self.transform_type,
             randomize=self.randomize,
@@ -223,10 +238,8 @@ def _create_r2_scheme(self, model: PreTrainedModel) -> TransformScheme:
             ],
         )
 
-
     def _create_r3_scheme(self) -> TransformScheme:
         raise NotImplementedError()
 
-
     def _create_r4_scheme(self) -> TransformScheme:
-        raise NotImplementedError()
+        raise NotImplementedError()
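
A minimal usage sketch of the `SpinQuantModifier` touched by this diff, for orientation only: it assumes the modifier is exported from `llmcompressor.modifiers` (the exact module path on this branch may differ) and that llmcompressor's `oneshot` entrypoint accepts a modifier instance as its recipe; the model name is purely illustrative.

```python
# Hypothetical usage sketch -- not part of this diff.
# Assumes SpinQuantModifier is importable from llmcompressor.modifiers
# (the exact export path on this branch may differ) and that `oneshot`
# accepts a modifier instance as the recipe. Model name is illustrative.
from transformers import AutoModelForCausalLM

from llmcompressor import oneshot
from llmcompressor.modifiers import SpinQuantModifier  # assumed export path

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")

# Fuse norms into adjacent linears and apply R1/R2 Hadamard rotations
# to the weights, matching the defaults defined in this file.
recipe = SpinQuantModifier(rotations=["R1", "R2"], transform_type="hadamard")

oneshot(model=model, recipe=recipe)
model.save_pretrained("Llama-3.2-1B-Instruct-spinquant")
```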