File tree Expand file tree Collapse file tree 2 files changed +4
-10
lines changed
src/llmcompressor/modifiers/transform/spinquant Expand file tree Collapse file tree 2 files changed +4
-10
lines changed Original file line number Diff line number Diff line change @@ -177,9 +177,7 @@ def _create_r1_scheme(self) -> TransformScheme:
177
177
),
178
178
TransformArgs (
179
179
targets = [
180
- self .mappings .attn_q ,
181
- self .mappings .attn_k ,
182
- self .mappings .attn_v ,
180
+ * self .mappings .attn_qkv ,
183
181
* self .mappings .mlp_in ,
184
182
self .mappings .lm_head ,
185
183
],
Original file line number Diff line number Diff line change 10
10
class SpinQuantMapping (BaseModel ):
11
11
embedding : str
12
12
13
- attn_q : str
14
- attn_k : str
15
- attn_v : str
13
+ attn_qkv : List [str ] # q_proj, k_proj, v_proj
16
14
attn_o : str
17
15
attn_head_dim : Optional [int ] = Field (default = None )
18
16
@@ -21,7 +19,7 @@ class SpinQuantMapping(BaseModel):
21
19
22
20
lm_head : str
23
21
24
- @field_validator ("mlp_in" , "mlp_out" , mode = "before" )
22
+ @field_validator ("attn_qkv" , "mlp_in" , "mlp_out" , mode = "before" )
25
23
def cast_to_list (cls , value ):
26
24
if isinstance (value , str ):
27
25
return [value ]
@@ -31,9 +29,7 @@ def cast_to_list(cls, value):
31
29
32
30
_default_mappings = SpinQuantMapping (
33
31
embedding = "re:.*embed_tokens$" ,
34
- attn_q = "re:.*q_proj$" ,
35
- attn_k = "re:.*k_proj$" ,
36
- attn_v = "re:.*v_proj$" ,
32
+ attn_qkv = ["re:.*q_proj$" ,"re:.*k_proj$" ,"re:.*v_proj$" ],
37
33
attn_o = "re:.*o_proj$" ,
38
34
mlp_in = ["re:.*up_proj$" , "re:.*gate_proj$" ],
39
35
mlp_out = "re:.*down_proj$" ,
You can’t perform that action at this time.
0 commit comments