Commit fce83be

use norm mappings
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 0e9af7b commit fce83be

File tree

src/llmcompressor/modeling/fuse.py
src/llmcompressor/modifiers/transform/spinquant/base.py

2 files changed: +54 -35 lines changed

src/llmcompressor/modeling/fuse.py

Lines changed: 19 additions & 6 deletions
@@ -8,7 +8,24 @@
 )
 from transformers.models.llama.modeling_llama import LlamaRMSNorm
 
-__all__ = ["fuse_norm_linears"]
+__all__ = ["normalize_embedding", "fuse_norm_linears"]
+
+
+PRECISION = torch.float64
+
+
+def normalize_embedding(embedding: torch.nn.Module):
+    if isinstance(embedding, (torch.nn.Embedding)):
+        with align_module_device(embedding):
+            weight_dtype = embedding.weight.dtype
+            weight = embedding.weight.to(PRECISION)
+            new_weight = weight - weight.mean(dim=-1, keepdim=True)
+            new_weight = new_weight.to(weight_dtype)
+
+        update_offload_parameter(embedding, "weight", new_weight)
+
+    else:
+        raise ValueError(f"Cannot normalize embedding of type {type(embedding)}")
 
 
 def fuse_norm_linears(norm: torch.nn.Module, linears: Iterable[torch.nn.Linear]):
@@ -29,11 +46,7 @@ def fuse_norm_linears(norm: torch.nn.Module, linears: Iterable[torch.nn.Linear]):
             linear, exec_device
         ):
             weight_dtype = linear.weight.dtype
-
-            new_weight = linear.weight.to(torch.float64) * norm.weight.to(
-                torch.float64
-            )
-
+            new_weight = linear.weight.to(PRECISION) * norm.weight.to(PRECISION)
            new_weight = new_weight.to(weight_dtype)
 
             update_offload_parameter(linear, "weight", new_weight)
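
The numerics behind these two helpers can be checked with a small standalone snippet. This is a hedged sketch, not the llmcompressor implementation: the hand-rolled rms_norm is for illustration only, and it assumes (as the rest of fuse_norm_linears presumably does) that the norm's learnable scale is replaced with ones after being folded into the downstream weights.

```python
# Hedged sketch, not the llmcompressor implementation. Assumes the norm's
# learnable scale is replaced with ones after being folded into the linear.
import torch

torch.manual_seed(0)
hidden, out = 16, 4
x = torch.randn(2, hidden, dtype=torch.float64)

def rms_norm(x, gamma, eps=1e-6):
    # RMSNorm: divide by the root-mean-square, then scale per channel by gamma
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps) * gamma

gamma = torch.rand(hidden, dtype=torch.float64) + 0.5  # norm scale ("norm.weight")
W = torch.randn(out, hidden, dtype=torch.float64)      # downstream linear weight

# reference: scaled norm followed by the linear
y_ref = rms_norm(x, gamma) @ W.T

# fused: gamma folded into the linear weight (the elementwise multiply above,
# done in PRECISION), with the norm scale set to ones
W_fused = W * gamma
y_fused = rms_norm(x, torch.ones(hidden, dtype=torch.float64)) @ W_fused.T
assert torch.allclose(y_ref, y_fused)

# embedding mean-centering, as in normalize_embedding: subtract each row's mean
# in high precision, then cast back to the original dtype
emb = torch.nn.Embedding(32, hidden)
centered = emb.weight.to(torch.float64)
centered = (centered - centered.mean(dim=-1, keepdim=True)).to(emb.weight.dtype)
```

The fold itself is exact because the per-channel scale commutes with the matrix multiply; the float64 (PRECISION) cast only limits round-off when the result is cast back to the original dtype.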

src/llmcompressor/modifiers/transform/spinquant/base.py

Lines changed: 35 additions & 29 deletions
@@ -1,6 +1,7 @@
 from enum import Enum
 from typing import Iterable, List, Literal, Optional
 
+from compressed_tensors import match_named_modules, is_match
 from compressed_tensors.transform import (
     TransformArgs,
     TransformConfig,
@@ -11,7 +12,7 @@
 from transformers import PreTrainedModel
 
 from llmcompressor.core import Event, EventType, State
-from llmcompressor.modeling import fuse_norm_linears
+from llmcompressor.modeling import normalize_embedding, fuse_norm_linears
 from llmcompressor.modifiers import Modifier
 
 
@@ -69,6 +70,10 @@ def cast_to_list(cls, value):
             norm="re:.*post_attention_layernorm$",
             linears=["re:.*up_proj$", "re:.*gate_proj$"],
         ),
+        NormMapping(
+            norm="model.norm",
+            linears=["lm_head"],
+        ),
     ]
 
 
@@ -132,36 +137,10 @@ def on_initialize(self, state: State, **kwargs) -> bool:
     def on_start(self, state: State, event: Event, **kwargs):
         self.started_ = True
 
-        # TODO: use norm mappings
-        # Embedding fusion
-        # theoretically, doesn't do anything. Doesn't seem to help model sanity either
-        from compressed_tensors import update_offload_parameter
-
-        for W in [state.model.model.embed_tokens]:
-            W_ = W.weight.data.double()
-            W.weight.data = (W_ - W_.mean(dim=-1, keepdim=True)).to(W.weight.data.dtype)
-
-        update_offload_parameter(state.model.model.embed_tokens, "weight", W.weight)
-
-        # TODO: use norm mappings
-        # layer norm fusion
-        for layer in state.model.model.layers:
-            fuse_norm_linears(
-                layer.input_layernorm,
-                (
-                    layer.self_attn.q_proj,
-                    layer.self_attn.k_proj,
-                    layer.self_attn.v_proj,
-                ),
-            )
-            fuse_norm_linears(
-                layer.post_attention_layernorm, (layer.mlp.gate_proj, layer.mlp.up_proj)
-            )
-
-        fuse_norm_linears(state.model.model.norm, (state.model.lm_head,))
-
         # needs to happen after the model has been hooked to execute on the GPU
         # otherwise we're applying weight transforms on CPU
+        self._prenormalize_embeddings(state.model)
+        self._fuse_norms(state.model)
         apply_transform_config(state.model, self.transform_config)
 
     def on_event(self, state: State, event: Event, **kwargs):
@@ -185,6 +164,33 @@ def on_finalize(self, state: State, **kwargs) -> bool:
 
         return True
 
+    def _prenormalize_embeddings(self, model: PreTrainedModel):
+        for _, embedding in match_named_modules(
+            model, [self.mappings.embedding], warn_on_fail=True
+        ):
+            normalize_embedding(embedding)
+
+    def _fuse_norms(self, model: PreTrainedModel):
+        for mapping in self.norm_mappings:
+            targets = (mapping.norm, *mapping.linears)
+            matches = dict()
+
+            for name, module in model.named_modules():
+                # match until we get a full set
+                for target in targets:
+                    if is_match(name, module, target):
+                        if target in matches:
+                            raise ValueError("Cannot match twice")
+                        matches[target] = module
+
+                # once we have a full set, fuse and reset
+                if all(target in matches for target in targets):
+                    fuse_norm_linears(
+                        matches[mapping.norm],
+                        (matches[target] for target in mapping.linears),
+                    )
+                    matches = dict()
+
     def _create_r1_scheme(self) -> TransformScheme:
         return TransformScheme(
             type=self.transform_type,
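
The new _fuse_norms walk relies on named_modules() yielding modules in definition order, so each decoder layer's (norm, linears) set completes before the next layer's modules appear, and the added model.norm/lm_head mapping completes at the end of the walk. Below is a simplified, self-contained sketch of that grouping pattern; the is_match stand-in (plain regex or exact-name matching), the print placeholder for fuse_norm_linears, and the toy module names are assumptions for illustration, not the compressed_tensors API.

```python
# Simplified sketch of the match-until-full-set grouping used by _fuse_norms.
# Hypothetical helpers for illustration; not the compressed_tensors is_match.
import re
import torch

def is_match(name: str, target: str) -> bool:
    # "re:"-prefixed targets are regexes, everything else is an exact name
    if target.startswith("re:"):
        return re.fullmatch(target[3:], name) is not None
    return name == target

def fuse_sets(model: torch.nn.Module, norm_target: str, linear_targets: list):
    targets = (norm_target, *linear_targets)
    matches = {}
    for name, module in model.named_modules():
        # collect modules until one module per target has been seen
        for target in targets:
            if is_match(name, target):
                if target in matches:
                    raise ValueError("Cannot match twice")
                matches[target] = module
        # a full set corresponds to one decoder layer: "fuse" it and start over
        if all(t in matches for t in targets):
            print(f"fuse {norm_target} -> {linear_targets} at {name}")
            matches = {}

# toy model with two "layers", each holding a norm and q/k/v projections
model = torch.nn.Module()
model.layers = torch.nn.ModuleList(
    torch.nn.ModuleDict(
        {
            "input_layernorm": torch.nn.LayerNorm(8),
            "q_proj": torch.nn.Linear(8, 8),
            "k_proj": torch.nn.Linear(8, 8),
            "v_proj": torch.nn.Linear(8, 8),
        }
    )
    for _ in range(2)
)

fuse_sets(
    model,
    "re:.*input_layernorm$",
    ["re:.*q_proj$", "re:.*k_proj$", "re:.*v_proj$"],
)
```

With the real mappings, the same loop also pairs model.norm with lm_head once the walk reaches the top-level modules, replacing the hard-coded per-layer fusion removed above.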
