Skip to content

Commit 11662d7

Browse files
merge attn_q etc. into attn_qkv
Signed-off-by: Brian Dellabetta <[email protected]>
1 parent d0e5bc5 commit 11662d7

File tree

2 files changed

+4
-10
lines changed

2 files changed

+4
-10
lines changed

src/llmcompressor/modifiers/transform/spinquant/base.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -177,9 +177,7 @@ def _create_r1_scheme(self) -> TransformScheme:
177177
),
178178
TransformArgs(
179179
targets=[
180-
self.mappings.attn_q,
181-
self.mappings.attn_k,
182-
self.mappings.attn_v,
180+
*self.mappings.attn_qkv,
183181
*self.mappings.mlp_in,
184182
self.mappings.lm_head,
185183
],

src/llmcompressor/modifiers/transform/spinquant/mappings.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,7 @@
1010
class SpinQuantMapping(BaseModel):
1111
embedding: str
1212

13-
attn_q: str
14-
attn_k: str
15-
attn_v: str
13+
attn_qkv: List[str] # q_proj, k_proj, v_proj
1614
attn_o: str
1715
attn_head_dim: Optional[int] = Field(default=None)
1816

@@ -21,7 +19,7 @@ class SpinQuantMapping(BaseModel):
2119

2220
lm_head: str
2321

24-
@field_validator("mlp_in", "mlp_out", mode="before")
22+
@field_validator("attn_qkv", "mlp_in", "mlp_out", mode="before")
2523
def cast_to_list(cls, value):
2624
if isinstance(value, str):
2725
return [value]
@@ -31,9 +29,7 @@ def cast_to_list(cls, value):
3129

3230
_default_mappings = SpinQuantMapping(
3331
embedding="re:.*embed_tokens$",
34-
attn_q="re:.*q_proj$",
35-
attn_k="re:.*k_proj$",
36-
attn_v="re:.*v_proj$",
32+
attn_qkv=["re:.*q_proj$","re:.*k_proj$","re:.*v_proj$"],
3733
attn_o="re:.*o_proj$",
3834
mlp_in=["re:.*up_proj$", "re:.*gate_proj$"],
3935
mlp_out="re:.*down_proj$",

0 commit comments

Comments
 (0)