@@ -119,16 +119,17 @@ def on_initialize(self, state: State, **kwargs) -> bool:
 
         self.mappings = infer_mapping_from_model(state.model)
         self.norm_mappings = infer_norm_mapping_from_model(state.model)
+        head_dim = self._infer_head_dim(state.model)
 
         config_groups = {}
         if SpinquantRotation.R1 in self.rotations:
             config_groups["R1"] = self._create_r1_scheme()
 
         if SpinquantRotation.R2 in self.rotations:
-            config_groups["R2"] = self._create_r2_scheme(state.model)
+            config_groups["R2"] = self._create_r2_scheme(head_dim)
 
         if SpinquantRotation.R3 in self.rotations:
-            config_groups["R3"] = self._create_r3_scheme()
+            config_groups["R3"] = self._create_r3_scheme(head_dim)
 
         if SpinquantRotation.R4 in self.rotations:
             config_groups["R4"] = self._create_r4_scheme()
@@ -209,16 +210,7 @@ def _create_r1_scheme(self) -> TransformScheme:
             ],
         )
 
-    def _create_r2_scheme(self, model: PreTrainedModel) -> TransformScheme:
-        config = model.config
-
-        if hasattr(config, "head_dim"):
-            head_dim = config.head_dim
-        elif hasattr(config, "hidden_size") and hasattr(config, "num_attention_heads"):
-            head_dim = config.hidden_size // config.num_attention_heads
-        else:
-            raise NotImplementedError()
-
+    def _create_r2_scheme(self, head_dim: int) -> TransformScheme:
         return TransformScheme(
             type=self.transform_type,
             randomize=self.randomize,
@@ -235,9 +227,23 @@ def _create_r2_scheme(self, model: PreTrainedModel) -> TransformScheme:
235
227
],
236
228
)
237
229
238
- def _create_r3_scheme (self ) -> TransformScheme :
239
- raise NotImplementedError (
240
- "SpinQuant R3 rotations will be added in a future release"
230
+ def _create_r3_scheme (self , head_dim : int ) -> TransformScheme :
231
+ return TransformScheme (
232
+ type = self .transform_type ,
233
+ randomize = self .randomize ,
234
+ requires_grad = self .learnable ,
235
+ precision = self .precision ,
236
+ head_dim = head_dim ,
237
+ apply = [
238
+ TransformArgs (
239
+ targets = [self .mappings .attn ],
240
+ location = "q_attn" ,
241
+ ),
242
+ TransformArgs (
243
+ targets = [self .mappings .attn ],
244
+ location = "k_cache" ,
245
+ ),
246
+ ],
241
247
)
242
248
243
249
def _create_r4_scheme (self ) -> TransformScheme :
@@ -258,3 +264,13 @@ def _create_r4_scheme(self) -> TransformScheme:
             ),
         ],
     )
+
+    def _infer_head_dim(self, model: PreTrainedModel) -> int:
+        config = model.config
+
+        if hasattr(config, "head_dim"):
+            return config.head_dim
+        elif hasattr(config, "hidden_size") and hasattr(config, "num_attention_heads"):
+            return config.hidden_size // config.num_attention_heads
+        else:
+            raise NotImplementedError()
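For reference, the fallback order that the new `_infer_head_dim` helper implements: an explicit `head_dim` attribute on the config wins, otherwise the value is derived as `hidden_size // num_attention_heads`. A minimal standalone sketch of that logic, using `SimpleNamespace` objects as stand-ins for a real `PretrainedConfig` (the stand-in configs are illustrative, not part of this change):

```python
from types import SimpleNamespace

def infer_head_dim(config) -> int:
    # Prefer an explicit head_dim attribute when the config provides one.
    if hasattr(config, "head_dim"):
        return config.head_dim
    # Otherwise derive it from hidden_size / num_attention_heads.
    elif hasattr(config, "hidden_size") and hasattr(config, "num_attention_heads"):
        return config.hidden_size // config.num_attention_heads
    else:
        raise NotImplementedError()

# An explicit head_dim takes precedence over the derived value.
explicit = SimpleNamespace(head_dim=64, hidden_size=4096, num_attention_heads=32)
derived = SimpleNamespace(hidden_size=4096, num_attention_heads=32)

assert infer_head_dim(explicit) == 64
assert infer_head_dim(derived) == 128  # 4096 // 32
```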
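A note on why R3 can target both `q_attn` and `k_cache` with the same `head_dim`-sized rotation: when queries and keys are rotated by the same orthogonal matrix, the rotation cancels in the Q·Kᵀ product, so attention logits are preserved exactly while the rotated values are typically friendlier to quantize. A small numpy illustration of that identity (this shows only the underlying linear algebra, not the llm-compressor transform machinery):

```python
import numpy as np

rng = np.random.default_rng(0)
head_dim = 64

# Random orthogonal rotation via QR decomposition; r3 @ r3.T == I.
r3, _ = np.linalg.qr(rng.normal(size=(head_dim, head_dim)))

q = rng.normal(size=(8, head_dim))  # per-head queries
k = rng.normal(size=(8, head_dim))  # per-head keys

scores = q @ k.T
rotated_scores = (q @ r3) @ (k @ r3).T  # rotate both sides with the same r3

# (q @ r3) @ (k @ r3).T == q @ (r3 @ r3.T) @ k.T == q @ k.T,
# so attention logits are unchanged by the shared rotation.
assert np.allclose(scores, rotated_scores)
```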