
Commit 0e9af7b

add missing norm fusion
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 5daa2d5 commit 0e9af7b

5 files changed: +62 additions, -44 deletions


examples/transform/spinquant_dummy.py

Lines changed: 4 additions & 5 deletions
@@ -1,14 +1,13 @@
-from datasets import load_dataset
-from transformers import AutoModelForCausalLM, AutoTokenizer
 import torch
 from compressed_tensors.utils import update_parameter_data
+from datasets import load_dataset
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers.models.llama.modeling_llama import LlamaRMSNorm
+
 from llmcompressor import oneshot
 from llmcompressor.modifiers.quantization import GPTQModifier, QuantizationModifier
 from llmcompressor.modifiers.transform import SpinQuantModifier
 from llmcompressor.utils import dispatch_for_generation
-from transformers.models.llama.modeling_llama import (
-    LlamaRMSNorm,
-)
 
 hidden_dim = intermediate_dim = 64
 up_dim = 128
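
For orientation, here is a minimal, hypothetical sketch of how imports like these are typically wired into a oneshot run. It is not the body of spinquant_dummy.py; the model id, quantization scheme, and modifier arguments are illustrative assumptions.

# Hypothetical wiring of the imports above; NOT the body of spinquant_dummy.py.
# The model id, scheme, and modifier arguments are assumptions.
from transformers import AutoModelForCausalLM

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.modifiers.transform import SpinQuantModifier
from llmcompressor.utils import dispatch_for_generation

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B-Instruct", torch_dtype="auto"  # placeholder model
)

recipe = [
    # rotate weights first (R1/R2 rotations are fused into the weights offline)
    SpinQuantModifier(rotations=["R1", "R2"], transform_type="hadamard"),
    # then quantize the rotated weights, e.g. data-free W4A16
    QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]),
]

oneshot(model=model, recipe=recipe)
dispatch_for_generation(model)  # place the model on GPU(s) for a sanity generation
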
src/llmcompressor/modeling/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
 # flake8: noqa
 
+from .fuse import *
 from .prepare import *
-from .fuse import *

src/llmcompressor/modeling/fuse.py

Lines changed: 12 additions & 6 deletions
@@ -1,8 +1,11 @@
 from typing import Iterable
 
 import torch
-from compressed_tensors import get_execution_device, align_module_device, update_offload_parameter
-
+from compressed_tensors import (
+    align_module_device,
+    get_execution_device,
+    update_offload_parameter,
+)
 from transformers.models.llama.modeling_llama import LlamaRMSNorm
 
 __all__ = ["fuse_norm_linears"]
@@ -22,14 +25,17 @@ def fuse_norm_linears(norm: torch.nn.Module, linears: Iterable[torch.nn.Linear])
     for linear in linears:
         # NOTE: spinquant does this op in float64
         exec_device = get_execution_device(norm)
-        with align_module_device(norm, exec_device), align_module_device(linear, exec_device):
-
+        with align_module_device(norm, exec_device), align_module_device(
+            linear, exec_device
+        ):
             weight_dtype = linear.weight.dtype
 
-            new_weight = linear.weight.to(torch.float64) * norm.weight.to(torch.float64)
+            new_weight = linear.weight.to(torch.float64) * norm.weight.to(
+                torch.float64
+            )
 
             new_weight = new_weight.to(weight_dtype)
-
+
             update_offload_parameter(linear, "weight", new_weight)
 
     update_offload_parameter(norm, "weight", torch.ones_like(norm.weight))
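
For context, fuse_norm_linears folds the RMSNorm weight into the input dimension of each downstream linear, then resets the norm weight to ones, leaving the network function unchanged. A minimal standalone check of that identity, assuming a hand-rolled RMSNorm and random weights (no llmcompressor calls):

# Standalone check of the fusion identity (hand-rolled RMSNorm; nothing here
# calls llmcompressor): Linear(W)(RMSNorm_g(x)) == Linear(W * g)(RMSNorm_ones(x)).
import torch

torch.manual_seed(0)
hidden = 8

def rms_norm(x, gain, eps=1e-6):
    # y = x / rms(x) * gain, applied over the last dimension
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * gain

gain = torch.rand(hidden) + 0.5                   # the norm's weight g
linear = torch.nn.Linear(hidden, 4, bias=False)   # a linear fed by the norm
x = torch.randn(2, hidden)

before = linear(rms_norm(x, gain))

with torch.no_grad():
    linear.weight.mul_(gain)   # mirrors the weight update in fuse_norm_linears

after = linear(rms_norm(x, torch.ones(hidden)))   # norm weight reset to ones
assert torch.allclose(before, after, atol=1e-6)
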
src/llmcompressor/modifiers/transform/spinquant/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-from .base import *
+from .base import *

src/llmcompressor/modifiers/transform/spinquant/base.py

Lines changed: 44 additions & 31 deletions
@@ -1,14 +1,18 @@
-from typing import Optional, List, Literal, Iterable
+from enum import Enum
+from typing import Iterable, List, Literal, Optional
 
-from compressed_tensors.transform import TransformConfig, TransformScheme, TransformArgs, apply_transform_config
-from pydantic import BaseModel, field_validator, Field
+from compressed_tensors.transform import (
+    TransformArgs,
+    TransformConfig,
+    TransformScheme,
+    apply_transform_config,
+)
+from pydantic import BaseModel, Field, field_validator
+from transformers import PreTrainedModel
 
 from llmcompressor.core import Event, EventType, State
 from llmcompressor.modeling import fuse_norm_linears
 from llmcompressor.modifiers import Modifier
-from enum import Enum
-
-from transformers import PreTrainedModel
 
 
 class SpinQuantMappings(BaseModel):
@@ -29,9 +33,10 @@ class SpinQuantMappings(BaseModel):
     def cast_to_list(cls, value):
         if isinstance(value, str):
             return [value]
-
+
         return value
-
+
+
 class NormMapping(BaseModel):
     norm: str
     linears: List[str]
@@ -40,22 +45,18 @@ class NormMapping(BaseModel):
     def cast_to_list(cls, value):
         if isinstance(value, str):
             return [value]
-
-        return value
 
+        return value
 
 
 llama_spinquant = SpinQuantMappings(
     embedding="re:.*embed_tokens$",
-
     attn_q="re:.*q_proj$",
     attn_k="re:.*k_proj$",
    attn_v="re:.*v_proj$",
     attn_o="re:.*o_proj$",
-
     mlp_in=["re:.*up_proj$", "re:.*gate_proj$"],
     mlp_out="re:.*down_proj$",
-
     lm_head="lm_head",
 )
 
@@ -67,25 +68,31 @@ def cast_to_list(cls, value):
     NormMapping(
         norm="re:.*post_attention_layernorm$",
         linears=["re:.*up_proj$", "re:.*gate_proj$"],
-    )
+    ),
 ]
 
+
 class SpinquantRotation(Enum):
     R1 = "R1"
     R2 = "R2"
     R3 = "R3"
     R4 = "R4"
 
+
 class SpinQuantModifier(Modifier):
     rotations: Iterable[SpinquantRotation] = ("R1", "R2")
-    transform_type: Literal["hadamard", "random-hadamard", "random-matrix"] = Field(default="hadamard")
+    transform_type: Literal["hadamard", "random-hadamard", "random-matrix"] = Field(
+        default="hadamard"
+    )
     randomize: bool = Field(default=False)
     learnable: bool = Field(default=False)
 
     mappings: Optional[SpinQuantMappings] = None
     norm_mappings: Optional[List[NormMapping]] = None
-
-    transform_config: Optional[TransformConfig] = None  # optional override for more fine-grained control
+
+    transform_config: Optional[TransformConfig] = (
+        None  # optional override for more fine-grained control
+    )
 
     @field_validator("rotations", mode="before")
     def validate_rotations(cls, value):
@@ -101,7 +108,7 @@ def on_initialize(self, state: State, **kwargs) -> bool:
         if self.transform_config is not None:
             if self.mappings is not None:
                 raise ValueError()
-
+
             return True
 
         config_groups = {}
@@ -129,6 +136,7 @@ def on_start(self, state: State, event: Event, **kwargs):
         # Embedding fusion
         # theoretically, doesn't do anything. Doesn't seem to help model sanity either
         from compressed_tensors import update_offload_parameter
+
         for W in [state.model.model.embed_tokens]:
             W_ = W.weight.data.double()
             W.weight.data = (W_ - W_.mean(dim=-1, keepdim=True)).to(W.weight.data.dtype)
@@ -138,16 +146,24 @@ def on_start(self, state: State, event: Event, **kwargs):
         # TODO: use norm mappings
         # layer norm fusion
         for layer in state.model.model.layers:
-            fuse_norm_linears(layer.input_layernorm, (layer.self_attn.q_proj, layer.self_attn.k_proj, layer.self_attn.v_proj))
-            fuse_norm_linears(layer.post_attention_layernorm, (layer.mlp.gate_proj, layer.mlp.up_proj))
+            fuse_norm_linears(
+                layer.input_layernorm,
+                (
+                    layer.self_attn.q_proj,
+                    layer.self_attn.k_proj,
+                    layer.self_attn.v_proj,
+                ),
+            )
+            fuse_norm_linears(
+                layer.post_attention_layernorm, (layer.mlp.gate_proj, layer.mlp.up_proj)
+            )
+
+        fuse_norm_linears(state.model.model.norm, (state.model.lm_head,))
 
         # needs to happen after the model has been hooked to execute on the GPU
         # otherwise we're applying weight transforms on CPU
         apply_transform_config(state.model, self.transform_config)
 
-
-
-
     def on_event(self, state: State, event: Event, **kwargs):
         if event.type_ == EventType.CALIBRATION_EPOCH_START:
             if not self.started_:
@@ -169,7 +185,6 @@ def on_finalize(self, state: State, **kwargs) -> bool:
 
         return True
 
-
     def _create_r1_scheme(self) -> TransformScheme:
         return TransformScheme(
             type=self.transform_type,
@@ -190,14 +205,14 @@ def _create_r1_scheme(self) -> TransformScheme:
                     self.mappings.attn_k,
                     self.mappings.attn_v,
                     *self.mappings.mlp_in,
-                    self.mappings.lm_head
+                    self.mappings.lm_head,
                 ],
                 location="weight_input",
                 inverse=True,
             ),
-        ]
+        ],
     )
-
+
     def _create_r2_scheme(self, model: PreTrainedModel) -> TransformScheme:
         config = model.config
 
@@ -207,7 +222,7 @@ def _create_r2_scheme(self, model: PreTrainedModel) -> TransformScheme:
             head_dim = config.hidden_size // config.num_attention_heads
         else:
             raise NotImplementedError()
-
+
         return TransformScheme(
            type=self.transform_type,
             randomize=self.randomize,
@@ -223,10 +238,8 @@ def _create_r2_scheme(self, model: PreTrainedModel) -> TransformScheme:
             ],
         )
 
-
     def _create_r3_scheme(self) -> TransformScheme:
         raise NotImplementedError()
 
-
     def _create_r4_scheme(self) -> TransformScheme:
-        raise NotImplementedError()
+        raise NotImplementedError()
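
The behavioral change in this file is the added fuse_norm_linears(state.model.model.norm, (state.model.lm_head,)) call, the missing norm fusion from the commit message: the final model norm, like the per-layer norms, must have its weight folded into the linear it feeds (here lm_head) before rotations are applied, because an orthogonal rotation only commutes with an RMSNorm whose weight is uniform. A standalone sketch of that property, using a hand-rolled RMSNorm and a random orthogonal matrix (not llmcompressor or compressed-tensors APIs):

# Standalone illustration: an orthogonal rotation commutes with RMSNorm only
# when the norm weight is uniform (all ones). Hand-rolled RMSNorm, random
# orthogonal matrix; nothing here is llmcompressor API.
import torch

torch.manual_seed(0)
hidden = 8

def rms_norm(x, gain, eps=1e-6):
    # y = x / rms(x) * gain over the last dimension
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * gain

Q, _ = torch.linalg.qr(torch.randn(hidden, hidden))  # random orthogonal rotation
x = torch.randn(3, hidden)
gain = torch.rand(hidden) + 0.5   # a non-trivial norm weight
ones = torch.ones(hidden)         # norm weight after fusion

# with a non-uniform weight, rotating before the norm != rotating after it
assert not torch.allclose(rms_norm(x @ Q, gain), rms_norm(x, gain) @ Q, atol=1e-5)

# with the weight fused away, the rotation passes through the norm unchanged,
# so a rotation inserted at the embeddings can be exactly inverted at lm_head
assert torch.allclose(rms_norm(x @ Q, ones), rms_norm(x, ones) @ Q, atol=1e-5)

This is presumably also why the fusion calls in on_start run before apply_transform_config: the rotations assume unit norm weights by the time they are folded into the surrounding linears.
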
