Commit 5daa2d5

embedding fusion
Signed-off-by: Kyle Sayers <[email protected]>
1 parent: fdb64b5

File tree

  • src/llmcompressor/modifiers/transform/spinquant

1 file changed: +12, -0 lines


src/llmcompressor/modifiers/transform/spinquant/base.py

Lines changed: 12 additions & 0 deletions
@@ -125,6 +125,18 @@ def on_initialize(self, state: State, **kwargs) -> bool:
     def on_start(self, state: State, event: Event, **kwargs):
         self.started_ = True
 
+        # TODO: use norm mappings
+        # Embedding fusion
+        # theoretically, doesn't do anything. Doesn't seem to help model sanity either
+        from compressed_tensors import update_offload_parameter
+        for W in [state.model.model.embed_tokens]:
+            W_ = W.weight.data.double()
+            W.weight.data = (W_ - W_.mean(dim=-1, keepdim=True)).to(W.weight.data.dtype)
+
+        update_offload_parameter(state.model.model.embed_tokens, "weight", W.weight)
+
+        # TODO: use norm mappings
+        # layer norm fusion
         for layer in state.model.model.layers:
             fuse_norm_linears(layer.input_layernorm, (layer.self_attn.q_proj, layer.self_attn.k_proj, layer.self_attn.v_proj))
             fuse_norm_linears(layer.post_attention_layernorm, (layer.mlp.gate_proj, layer.mlp.up_proj))
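For reference, the added lines mean-center each row of the embedding table: the row mean over the hidden dimension is subtracted in float64, the result is cast back to the original dtype, and the offloaded copy of the parameter is refreshed. Below is a minimal, standalone sketch of that operation on a toy nn.Embedding; it is an illustration under assumed shapes, not llmcompressor code.

import torch

# toy stand-in for state.model.model.embed_tokens (assumption: [vocab, hidden] weight)
embed = torch.nn.Embedding(num_embeddings=32, embedding_dim=8, dtype=torch.float16)

W_ = embed.weight.data.double()                      # upcast for a numerically stable mean
centered = W_ - W_.mean(dim=-1, keepdim=True)        # zero-mean each embedding row
embed.weight.data = centered.to(embed.weight.dtype)  # cast back to the original dtype in place

# every row now has (approximately) zero mean over the hidden dimension
assert torch.allclose(
    embed.weight.data.double().mean(dim=-1),
    torch.zeros(32, dtype=torch.float64),
    atol=1e-3,
)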
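The pre-existing loop that follows the addition calls fuse_norm_linears on each decoder layer. Below is a conceptual sketch of that fusion, under the assumption that it folds the norm's learned scale into the input dimension of the following linear layers and resets the norm weight to ones; the real fuse_norm_linears helper may differ in its details.

import torch

hidden = 8

# stand-ins for layer.input_layernorm.weight and layer.self_attn.q_proj (assumed shapes)
norm_weight = torch.rand(hidden) + 0.5
linear = torch.nn.Linear(hidden, 16, bias=False)

def rms_norm(x, weight, eps=1e-6):
    # weighted RMSNorm as used by Llama-style models
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight

x = torch.randn(hidden)
before = linear(rms_norm(x, norm_weight))

# fuse: scale each input column of the linear by the norm weight, then
# replace the norm weight with ones so the norm itself becomes weightless
linear.weight.data = linear.weight.data * norm_weight
norm_weight = torch.ones_like(norm_weight)

after = linear(rms_norm(x, norm_weight))
assert torch.allclose(before, after, atol=1e-5)

A weightless RMSNorm commutes with orthogonal rotations of the residual stream, which is why QuaRot/SpinQuant-style methods perform this kind of fusion before inserting their rotations.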
