
Commit 7e6ea83

R4
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 2d54f29 commit 7e6ea83

2 files changed (+33, -15 lines)

src/llmcompressor/modifiers/transform/spinquant/base.py

Lines changed: 31 additions & 15 deletions

@@ -119,16 +119,17 @@ def on_initialize(self, state: State, **kwargs) -> bool:
 
         self.mappings = infer_mapping_from_model(state.model)
         self.norm_mappings = infer_norm_mapping_from_model(state.model)
+        head_dim = self._infer_head_dim(state.model)
 
         config_groups = {}
         if SpinquantRotation.R1 in self.rotations:
             config_groups["R1"] = self._create_r1_scheme()
 
         if SpinquantRotation.R2 in self.rotations:
-            config_groups["R2"] = self._create_r2_scheme(state.model)
+            config_groups["R2"] = self._create_r2_scheme(head_dim)
 
         if SpinquantRotation.R3 in self.rotations:
-            config_groups["R3"] = self._create_r3_scheme()
+            config_groups["R3"] = self._create_r3_scheme(head_dim)
 
         if SpinquantRotation.R4 in self.rotations:
             config_groups["R4"] = self._create_r4_scheme()
@@ -209,16 +210,7 @@ def _create_r1_scheme(self) -> TransformScheme:
             ],
         )
 
-    def _create_r2_scheme(self, model: PreTrainedModel) -> TransformScheme:
-        config = model.config
-
-        if hasattr(config, "head_dim"):
-            head_dim = config.head_dim
-        elif hasattr(config, "hidden_size") and hasattr(config, "num_attention_heads"):
-            head_dim = config.hidden_size // config.num_attention_heads
-        else:
-            raise NotImplementedError()
-
+    def _create_r2_scheme(self, head_dim: int) -> TransformScheme:
         return TransformScheme(
             type=self.transform_type,
             randomize=self.randomize,
@@ -235,9 +227,23 @@ def _create_r2_scheme(self, model: PreTrainedModel) -> TransformScheme:
             ],
         )
 
-    def _create_r3_scheme(self) -> TransformScheme:
-        raise NotImplementedError(
-            "SpinQuant R3 rotations will be added in a future release"
+    def _create_r3_scheme(self, head_dim: int) -> TransformScheme:
+        return TransformScheme(
+            type=self.transform_type,
+            randomize=self.randomize,
+            requires_grad=self.learnable,
+            precision=self.precision,
+            head_dim=head_dim,
+            apply=[
+                TransformArgs(
+                    targets=[self.mappings.attn],
+                    location="q_attn",
+                ),
+                TransformArgs(
+                    targets=[self.mappings.attn],
+                    location="k_cache",
+                ),
+            ],
         )
 
     def _create_r4_scheme(self) -> TransformScheme:
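
Note: the hunk above replaces the previous stub, which raised NotImplementedError for R3. A minimal usage sketch follows, assuming the SpinQuantModifier class exported from this package and a `rotations` constructor argument (inferred from the `self.rotations` checks in on_initialize; the exact import path and signature are assumptions, not confirmed by this diff):

    # Sketch only: import path and constructor signature are assumptions.
    from llmcompressor.modifiers.transform import SpinQuantModifier

    # With this commit, requesting R3 builds a TransformScheme targeting
    # whole attention modules at the "q_attn" and "k_cache" locations
    # instead of raising NotImplementedError.
    modifier = SpinQuantModifier(rotations=["R1", "R2", "R3", "R4"])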
@@ -258,3 +264,13 @@ def _create_r4_scheme(self) -> TransformScheme:
                 ),
             ],
         )
+
+    def _infer_head_dim(self, model: PreTrainedModel) -> int:
+        config = model.config
+
+        if hasattr(config, "head_dim"):
+            return config.head_dim
+        elif hasattr(config, "hidden_size") and hasattr(config, "num_attention_heads"):
+            return config.hidden_size // config.num_attention_heads
+        else:
+            raise NotImplementedError()
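
Note: the new `_infer_head_dim` helper centralizes the head-dimension lookup that `_create_r2_scheme` previously performed inline, so R2 and R3 now share it. A standalone sketch of the same fallback logic, using a stand-in config object (real Hugging Face configs such as `transformers.LlamaConfig` expose the same `head_dim`, `hidden_size`, and `num_attention_heads` fields):

    from types import SimpleNamespace

    def infer_head_dim(config) -> int:
        # Mirrors the helper added in this commit: prefer an explicit
        # head_dim, otherwise derive it from hidden_size and the head count.
        if hasattr(config, "head_dim"):
            return config.head_dim
        elif hasattr(config, "hidden_size") and hasattr(config, "num_attention_heads"):
            return config.hidden_size // config.num_attention_heads
        else:
            raise NotImplementedError()

    # Stand-in for a config without an explicit head_dim attribute.
    config = SimpleNamespace(hidden_size=2048, num_attention_heads=32)
    print(infer_head_dim(config))  # 2048 // 32 = 64

One caveat the sketch inherits from the helper: a config that defines `head_dim` with the value None still takes the first branch and returns None rather than falling back to the derived value.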

src/llmcompressor/modifiers/transform/spinquant/mappings.py

Lines changed: 2 additions & 0 deletions

@@ -29,6 +29,7 @@ class SpinQuantMapping(BaseModel):
 
     embedding: str
 
+    attn: str
     attn_q: str
     attn_k: str
     attn_v: str
@@ -50,6 +51,7 @@ def cast_to_list(cls, value):
 
 _default_mappings = SpinQuantMapping(
     embedding="re:.*embed_tokens$",
+    attn="re:.*self_attn$",
     attn_q="re:.*q_proj$",
     attn_k="re:.*k_proj$",
     attn_v="re:.*v_proj$",
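
Note: the new `attn` entry gives R3 a target for whole attention modules rather than the individual q/k/v projections. As with the existing entries, the `re:` prefix marks the value as a regular expression over module names; a quick sanity check of the pattern with the standard `re` module, assuming the matcher applies the expression to fully qualified module names as the other patterns suggest:

    import re

    # The pattern after the "re:" prefix, applied to module names.
    pattern = re.compile(r".*self_attn$")

    print(bool(pattern.match("model.layers.0.self_attn")))         # True
    print(bool(pattern.match("model.layers.0.self_attn.q_proj")))  # False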
