Commit 7172c26

spinquant modifier
Signed-off-by: Kyle Sayers <[email protected]>
1 parent dc5c30c commit 7172c26

10 files changed: +221 −80 lines


examples/transform/llama3_example.py

Lines changed: 2 additions & 2 deletions

@@ -3,7 +3,7 @@
 
 from llmcompressor import oneshot
 from llmcompressor.modifiers.quantization import GPTQModifier, QuantizationModifier
-from llmcompressor.modifiers.transform import TransformModifier
+from llmcompressor.modifiers.transform import SpinQuantModifier
 from llmcompressor.utils import dispatch_for_generation
 
 # Select model and load it.
@@ -62,7 +62,7 @@ def tokenize(sample):
     # TODO preset_config="QUIP_ONLINE" outputs gibberish
     # preset_config="QUIP" output sensible, but cannot load saved
     # checkpoint or run evals (~4hrs to run)
-    TransformModifier(preset_config="LLAMA_SPINQUANT_R1R2"),
+    SpinQuantModifier(rotations=["R1", "R2"]),
     # QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]),
 ]
 
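For orientation, a minimal sketch of how the updated recipe might be applied end to end. The model id, dataset, and calibration settings below are placeholders rather than the values used in this example, and QuantizationModifier is enabled here even though the example above leaves it commented out:

from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.modifiers.transform import SpinQuantModifier

MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"  # placeholder model id
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

recipe = [
    # fuse R1/R2 rotations into the weights before quantization
    SpinQuantModifier(rotations=["R1", "R2"]),
    # quantize the rotated weights
    QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]),
]

oneshot(
    model=model,
    dataset="ultrachat_200k",  # placeholder calibration dataset
    recipe=recipe,
    max_seq_length=2048,
    num_calibration_samples=512,
)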

examples/transform/spinquant_dummy.py

Lines changed: 2 additions & 2 deletions

@@ -4,7 +4,7 @@
 from compressed_tensors.utils import update_parameter_data
 from llmcompressor import oneshot
 from llmcompressor.modifiers.quantization import GPTQModifier, QuantizationModifier
-from llmcompressor.modifiers.transform import TransformModifier
+from llmcompressor.modifiers.transform import SpinQuantModifier
 from llmcompressor.utils import dispatch_for_generation
 from transformers.models.llama.modeling_llama import (
     LlamaRMSNorm,
@@ -94,7 +94,7 @@ def forward(self, input_ids):
 recipe = [
     # NOTE: preset_config="QUIP" output sensible, but cannot load saved
     # checkpoint or run evals (~4hrs to run)
-    TransformModifier(preset_config="LLAMA_SPINQUANT_R1R2"),
+    SpinQuantModifier(rotations=["R1", "R2"]),
     # QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]),
 ]
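Both examples replace the preset-based TransformModifier with SpinQuantModifier(rotations=["R1", "R2"]). As a quick illustration (plain torch, not llm-compressor API) of the identity these fused rotations rely on: applying an orthogonal R to one weight's output and R's inverse to the next weight's input leaves the composed mapping unchanged up to floating-point error, which is why the rotations can be folded into the checkpoint without changing model outputs:

import torch

torch.manual_seed(0)
hidden = 64

x = torch.randn(4, hidden)
w_a = torch.randn(hidden, hidden)  # stands in for an "output-side" weight (e.g. o_proj)
w_b = torch.randn(hidden, hidden)  # stands in for the next "input-side" weight

# random orthogonal matrix standing in for a Hadamard rotation
r, _ = torch.linalg.qr(torch.randn(hidden, hidden))

baseline = x @ w_a @ w_b
rotated = x @ (w_a @ r) @ (r.T @ w_b)  # rotate w_a's output, un-rotate w_b's input

assert torch.allclose(baseline, rotated, atol=1e-3)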

Lines changed: 1 addition & 2 deletions

@@ -1,4 +1,3 @@
 # flake8: noqa
 
-from .presets import TRANSFORM_PRESETS
-from .transform import TransformModifier
+from .spinquant import SpinQuantModifier

src/llmcompressor/modifiers/transform/presets/__init__.py

Lines changed: 0 additions & 9 deletions
This file was deleted.

src/llmcompressor/modifiers/transform/quip/base.py

Whitespace-only changes.

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+from .base import *
Lines changed: 215 additions & 0 deletions

@@ -0,0 +1,215 @@
from typing import Optional, List, Literal

from compressed_tensors.transform import TransformConfig, TransformScheme, TransformArgs, apply_transform_config
from pydantic import BaseModel, field_validator, Field

from llmcompressor.core import Event, EventType, State
from llmcompressor.modeling import fuse_norm_linears
from llmcompressor.modifiers import Modifier
from enum import Enum

from transformers import PreTrainedModel


class SpinQuantMappings(BaseModel):
    embedding: str

    attn_q: str
    attn_k: str
    attn_v: str
    attn_o: str
    attn_head_dim: Optional[int] = Field(default=None)

    mlp_in: List[str]   # up_proj, gate_proj
    mlp_out: List[str]  # down_proj

    lm_head: str

    @field_validator("mlp_in", "mlp_out", mode="before")
    def cast_to_list(cls, value):
        if isinstance(value, str):
            return [value]

        return value


class NormMapping(BaseModel):
    norm: str
    linears: List[str]

    @field_validator("linears", mode="before")
    def cast_to_list(cls, value):
        if isinstance(value, str):
            return [value]

        return value


llama_spinquant = SpinQuantMappings(
    embedding="re:.*embed_tokens$",
    attn_q="re:.*q_proj$",
    attn_k="re:.*k_proj$",
    attn_v="re:.*v_proj$",
    attn_o="re:.*o_proj$",
    mlp_in=["re:.*up_proj$", "re:.*gate_proj$"],
    mlp_out="re:.*down_proj$",
    lm_head="lm_head",
)

llama_norm_mappings = [
    NormMapping(
        norm="re:.*input_layernorm$",
        linears=["re:.*q_proj$", "re:.*k_proj$", "re:.*v_proj$"],
    ),
    NormMapping(
        norm="re:.*post_attention_layernorm$",
        linears=["re:.*up_proj$", "re:.*gate_proj$"],
    ),
]


class SpinquantRotation(Enum):
    R1 = "R1"
    R2 = "R2"
    R3 = "R3"
    R4 = "R4"


class SpinQuantModifier(Modifier):
    rotations: List[SpinquantRotation] = Field(default_factory=lambda: ["R1", "R2"])

    transform_type: Literal["hadamard", "random-hadamard", "random-matrix"] = Field(default="hadamard")
    randomize: bool = Field(default=False)
    learnable: bool = Field(default=False)

    mappings: Optional[SpinQuantMappings] = None
    norm_mappings: Optional[List[NormMapping]] = None

    # optional override for more fine-grained control
    transform_config: Optional[TransformConfig] = None

    def on_initialize(self, state: State, **kwargs) -> bool:
        # HARDCODE
        self.mappings = llama_spinquant
        self.norm_mappings = llama_norm_mappings

        if self.transform_config is not None:
            if self.mappings is not None:
                raise ValueError()

            return True

        config_groups = {}
        for rotation in self.rotations:
            if rotation == SpinquantRotation.R1:
                config_groups["R1"] = self._create_r1_scheme()

            if rotation == SpinquantRotation.R2:
                config_groups["R2"] = self._create_r2_scheme(state.model)

            if rotation == SpinquantRotation.R3:
                config_groups["R3"] = self._create_r3_scheme()

            if rotation == SpinquantRotation.R4:
                config_groups["R4"] = self._create_r4_scheme()

        self.transform_config = TransformConfig(config_groups=config_groups)

        return True

    def on_start(self, state: State, event: Event, **kwargs):
        self.started_ = True

        for layer in state.model.model.layers:
            fuse_norm_linears(
                layer.input_layernorm,
                (layer.self_attn.q_proj, layer.self_attn.k_proj, layer.self_attn.v_proj),
            )
            fuse_norm_linears(
                layer.post_attention_layernorm,
                (layer.mlp.gate_proj, layer.mlp.up_proj),
            )

        # needs to happen after the model has been hooked to execute on the GPU,
        # otherwise we're applying weight transforms on CPU
        apply_transform_config(state.model, self.transform_config)

    def on_event(self, state: State, event: Event, **kwargs):
        if event.type_ == EventType.CALIBRATION_EPOCH_START:
            if not self.started_:
                self.on_start(state, None)

        elif event.type_ == EventType.SEQUENTIAL_EPOCH_END:
            pass

        elif event.type_ == EventType.CALIBRATION_EPOCH_END:
            if not self.ended_:
                self.on_end(state, None)

    def on_end(self, state: State, event: Event, **kwargs):
        self.ended_ = True

    def on_finalize(self, state: State, **kwargs) -> bool:
        if not self.ended_:
            self.on_end(state, None)

        return True

    def _create_r1_scheme(self) -> TransformScheme:
        return TransformScheme(
            type=self.transform_type,
            randomize=self.randomize,
            requires_grad=self.learnable,
            apply=[
                TransformArgs(
                    targets=[
                        self.mappings.embedding,
                        self.mappings.attn_o,
                        *self.mappings.mlp_out,
                    ],
                    location="weight_output",
                ),
                TransformArgs(
                    targets=[
                        self.mappings.attn_q,
                        self.mappings.attn_k,
                        self.mappings.attn_v,
                        *self.mappings.mlp_in,
                        self.mappings.lm_head,
                    ],
                    location="weight_input",
                    inverse=True,
                ),
            ],
        )

    def _create_r2_scheme(self, model: PreTrainedModel) -> TransformScheme:
        config = model.config

        if hasattr(config, "head_dim"):
            head_dim = config.head_dim
        elif hasattr(config, "hidden_size") and hasattr(config, "num_attention_heads"):
            head_dim = config.hidden_size // config.num_attention_heads
        else:
            raise NotImplementedError()

        return TransformScheme(
            type=self.transform_type,
            randomize=self.randomize,
            requires_grad=self.learnable,
            head_dim=head_dim,
            apply=[
                TransformArgs(targets=[self.mappings.attn_v], location="weight_output"),
                TransformArgs(
                    targets=[self.mappings.attn_o],
                    location="weight_input",
                    inverse=True,
                ),
            ],
        )

    def _create_r3_scheme(self) -> TransformScheme:
        raise NotImplementedError()

    def _create_r4_scheme(self) -> TransformScheme:
        raise NotImplementedError()
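One note on ordering: on_start fuses each norm into the linears that consume it before apply_transform_config runs. A per-channel RMSNorm gain does not commute with a rotation, so the gain must first be folded into the following linear weights, leaving a unit-gain norm that does commute. A minimal sketch of that folding (plain torch, not the llmcompressor.modeling.fuse_norm_linears implementation):

import torch

torch.manual_seed(0)
hidden = 8

gamma = torch.rand(hidden) + 0.5  # RMSNorm per-channel gain
linear = torch.nn.Linear(hidden, hidden, bias=False)

x = torch.randn(2, hidden)
normed = x / x.pow(2).mean(dim=-1, keepdim=True).sqrt()  # RMS-normalized activations

out_before = linear(normed * gamma)  # gain applied by the norm, then the linear

fused = torch.nn.Linear(hidden, hidden, bias=False)
with torch.no_grad():
    fused.weight.copy_(linear.weight * gamma)  # fold gamma into the linear's input columns
out_after = fused(normed)  # the norm is now effectively unit-gain

assert torch.allclose(out_before, out_after, atol=1e-5)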

src/llmcompressor/modifiers/transform/transform.py

Lines changed: 0 additions & 65 deletions
This file was deleted.
