 from enum import Enum
 from typing import Iterable, List, Literal, Optional

+from compressed_tensors import match_named_modules, is_match
 from compressed_tensors.transform import (
     TransformArgs,
     TransformConfig,
 from transformers import PreTrainedModel

 from llmcompressor.core import Event, EventType, State
-from llmcompressor.modeling import fuse_norm_linears
+from llmcompressor.modeling import normalize_embedding, fuse_norm_linears
 from llmcompressor.modifiers import Modifier

@@ -69,6 +70,10 @@ def cast_to_list(cls, value):
         norm="re:.*post_attention_layernorm$",
         linears=["re:.*up_proj$", "re:.*gate_proj$"],
     ),
+    NormMapping(
+        norm="model.norm",
+        linears=["lm_head"],
+    ),
 ]

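Note: the new mapping covers the final `model.norm` → `lm_head` pair that was previously fused by hand in `on_start` (removed in the next hunk). NormMapping targets mix literal module names with "re:"-prefixed regular expressions; resolution is done by the compressed-tensors matching utilities (`is_match` / `match_named_modules` used later in this diff). A rough sketch of the target syntax only, with a hypothetical helper name:

    import re

    def matches_target(module_name: str, target: str) -> bool:
        # Illustration of the target convention used above: "re:"-prefixed
        # targets are regular expressions matched against the module name;
        # anything else is compared as a literal module name.
        if target.startswith("re:"):
            return re.match(target[len("re:"):], module_name) is not None
        return module_name == target

    assert matches_target(
        "model.layers.0.post_attention_layernorm",
        "re:.*post_attention_layernorm$",
    )
    assert matches_target("model.norm", "model.norm")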
@@ -132,36 +137,10 @@ def on_initialize(self, state: State, **kwargs) -> bool:
     def on_start(self, state: State, event: Event, **kwargs):
         self.started_ = True

-        # TODO: use norm mappings
-        # Embedding fusion
-        # theoretically, doesn't do anything. Doesn't seem to help model sanity either
-        from compressed_tensors import update_offload_parameter
-
-        for W in [state.model.model.embed_tokens]:
-            W_ = W.weight.data.double()
-            W.weight.data = (W_ - W_.mean(dim=-1, keepdim=True)).to(W.weight.data.dtype)
-
-        update_offload_parameter(state.model.model.embed_tokens, "weight", W.weight)
-
-        # TODO: use norm mappings
-        # layer norm fusion
-        for layer in state.model.model.layers:
-            fuse_norm_linears(
-                layer.input_layernorm,
-                (
-                    layer.self_attn.q_proj,
-                    layer.self_attn.k_proj,
-                    layer.self_attn.v_proj,
-                ),
-            )
-            fuse_norm_linears(
-                layer.post_attention_layernorm, (layer.mlp.gate_proj, layer.mlp.up_proj)
-            )
-
-        fuse_norm_linears(state.model.model.norm, (state.model.lm_head,))
-
         # needs to happen after the model has been hooked to execute on the GPU
         # otherwise we're applying weight transforms on CPU
+        self._prenormalize_embeddings(state.model)
+        self._fuse_norms(state.model)
         apply_transform_config(state.model, self.transform_config)

     def on_event(self, state: State, event: Event, **kwargs):
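The two removed blocks are what the new helpers encapsulate: mean-centering of the embedding weights and mapping-driven norm fusion. As a sketch of the embedding step only, mirroring the removed inline arithmetic (the real helper is `llmcompressor.modeling.normalize_embedding`, which presumably also takes care of offloaded parameters):

    import torch

    def center_embedding_rows(embedding: torch.nn.Embedding) -> None:
        # Same arithmetic as the removed inline code: subtract each row's mean
        # in float64 for stability, then cast back to the original dtype.
        with torch.no_grad():
            W = embedding.weight.data.double()
            centered = W - W.mean(dim=-1, keepdim=True)
            embedding.weight.data = centered.to(embedding.weight.dtype)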
@@ -185,6 +164,33 @@ def on_finalize(self, state: State, **kwargs) -> bool:

         return True

+    def _prenormalize_embeddings(self, model: PreTrainedModel):
+        for _, embedding in match_named_modules(
+            model, [self.mappings.embedding], warn_on_fail=True
+        ):
+            normalize_embedding(embedding)
+
+    def _fuse_norms(self, model: PreTrainedModel):
+        for mapping in self.norm_mappings:
+            targets = (mapping.norm, *mapping.linears)
+            matches = dict()
+
+            for name, module in model.named_modules():
+                # match until we get a full set
+                for target in targets:
+                    if is_match(name, module, target):
+                        if target in matches:
+                            raise ValueError("Cannot match twice")
+                        matches[target] = module
+
+                # once we have a full set, fuse and reset
+                if all(target in matches for target in targets):
+                    fuse_norm_linears(
+                        matches[mapping.norm],
+                        (matches[target] for target in mapping.linears),
+                    )
+                    matches = dict()
+
     def _create_r1_scheme(self) -> TransformScheme:
         return TransformScheme(
             type=self.transform_type,
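For context on what `_fuse_norms` delegates to: `fuse_norm_linears` folds the norm's learned scale into the following linear layers so the transforms applied afterwards see an unscaled norm. A rough sketch under the assumption of an RMSNorm-style module with a single weight vector (the real routine lives in `llmcompressor.modeling` and may differ in details such as dtype handling and offload support):

    import torch

    def fuse_norm_into_linears(norm: torch.nn.Module, linears) -> None:
        # Sketch only: fold the norm scale into the input dimension of each
        # downstream linear, then neutralize the norm so the composed
        # computation is numerically unchanged.
        with torch.no_grad():
            for linear in linears:
                fused = linear.weight.data.double() * norm.weight.data.double()
                linear.weight.data = fused.to(linear.weight.dtype)
            norm.weight.data = torch.ones_like(norm.weight.data)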