from enum import Enum
from typing import Iterable, List, Literal, Optional

-from compressed_tensors import match_named_modules, is_match
+from compressed_tensors import is_match, match_named_modules
from compressed_tensors.transform import (
    TransformArgs,
    TransformConfig,
    TransformScheme,
    apply_transform_config,
)
-from pydantic import BaseModel, Field, field_validator
+from pydantic import Field, field_validator
from transformers import PreTrainedModel

from llmcompressor.core import Event, EventType, State
-from llmcompressor.modeling import normalize_embedding, fuse_norm_linears
+from llmcompressor.modeling import fuse_norm_linears, normalize_embedding
from llmcompressor.modifiers import Modifier

-
-class SpinQuantMappings(BaseModel):
-    embedding: str
-
-    attn_q: str
-    attn_k: str
-    attn_v: str
-    attn_o: str
-    attn_head_dim: Optional[int] = Field(default=None)
-
-    mlp_in: List[str]  # up_proj, gate_proj
-    mlp_out: List[str]  # down_proj
-
-    lm_head: str
-
-    @field_validator("mlp_in", "mlp_out", mode="before")
-    def cast_to_list(cls, value):
-        if isinstance(value, str):
-            return [value]
-
-        return value
-
-
-class NormMapping(BaseModel):
-    norm: str
-    linears: List[str]
-
-    @field_validator("linears", mode="before")
-    def cast_to_list(cls, value):
-        if isinstance(value, str):
-            return [value]
-
-        return value
-
-
-llama_spinquant = SpinQuantMappings(
-    embedding="re:.*embed_tokens$",
-    attn_q="re:.*q_proj$",
-    attn_k="re:.*k_proj$",
-    attn_v="re:.*v_proj$",
-    attn_o="re:.*o_proj$",
-    mlp_in=["re:.*up_proj$", "re:.*gate_proj$"],
-    mlp_out="re:.*down_proj$",
-    lm_head="lm_head",
-)
-
-llama_norm_mappings = [
-    NormMapping(
-        norm="re:.*input_layernorm$",
-        linears=["re:.*q_proj$", "re:.*k_proj$", "re:.*v_proj$"],
-    ),
-    NormMapping(
-        norm="re:.*post_attention_layernorm$",
-        linears=["re:.*up_proj$", "re:.*gate_proj$"],
-    ),
-    NormMapping(
-        norm="model.norm",
-        linears=["lm_head"],
-    ),
-]
+from .mappings import SPINQUANT_MAPPING_REGISTRY, SpinQuantMappings
+from .norm_mappings import NORM_MAPPING_REGISTRY, NormMapping


class SpinquantRotation(Enum):
@@ -92,12 +34,15 @@ class SpinQuantModifier(Modifier):
    randomize: bool = Field(default=False)
    learnable: bool = Field(default=False)

+    # norm mappings are kept separate from spinquant mappings so users can
+    # override spinquant mappings via transform_config without overriding norm
+    # fusing; the two could be combined, but that requires more validation
+    # logic, and separate norm mappings may be reusable by other modifiers
    mappings: Optional[SpinQuantMappings] = None
    norm_mappings: Optional[List[NormMapping]] = None

-    transform_config: Optional[TransformConfig] = (
-        None  # optional override for more fine-grained control
-    )
+    # optional override for more fine-grained control
+    transform_config: Optional[TransformConfig] = None

    @field_validator("rotations", mode="before")
    def validate_rotations(cls, value):
@@ -106,9 +51,9 @@ def validate_rotations(cls, value):
        return value

    def on_initialize(self, state: State, **kwargs) -> bool:
-        # HARDCODE
-        self.mappings = llama_spinquant
-        self.norm_mappings = llama_norm_mappings
+        # TODO: more validation
+        self.mappings = SPINQUANT_MAPPING_REGISTRY[state.model.__class__.__name__]
+        self.norm_mappings = NORM_MAPPING_REGISTRY[state.model.__class__.__name__]

        if self.transform_config is not None:
            if self.mappings is not None:
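For reference, the definitions removed above presumably now live in the new sibling `mappings` module imported at the top of the file. Below is a minimal sketch of what that module might contain, reconstructed from the removed code; the plain-dict registry shape and the `"LlamaForCausalLM"` key are assumptions inferred from the `state.model.__class__.__name__` lookup in `on_initialize`:

```python
# mappings.py (sketch; reconstructed from the definitions removed above)
from typing import List, Optional

from pydantic import BaseModel, Field, field_validator


class SpinQuantMappings(BaseModel):
    embedding: str

    attn_q: str
    attn_k: str
    attn_v: str
    attn_o: str
    attn_head_dim: Optional[int] = Field(default=None)

    mlp_in: List[str]  # up_proj, gate_proj
    mlp_out: List[str]  # down_proj

    lm_head: str

    @field_validator("mlp_in", "mlp_out", mode="before")
    def cast_to_list(cls, value):
        # accept a single pattern string where a list of patterns is expected
        if isinstance(value, str):
            return [value]
        return value


_llama_spinquant = SpinQuantMappings(
    embedding="re:.*embed_tokens$",
    attn_q="re:.*q_proj$",
    attn_k="re:.*k_proj$",
    attn_v="re:.*v_proj$",
    attn_o="re:.*o_proj$",
    mlp_in=["re:.*up_proj$", "re:.*gate_proj$"],
    mlp_out="re:.*down_proj$",
    lm_head="lm_head",
)

# keyed by model class name to match the state.model.__class__.__name__
# lookup in on_initialize (the exact key and dict shape are assumptions)
SPINQUANT_MAPPING_REGISTRY = {
    "LlamaForCausalLM": _llama_spinquant,
}
```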
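Likewise, a sketch of the new `norm_mappings` module under the same assumptions:

```python
# norm_mappings.py (sketch; reconstructed from the definitions removed above)
from typing import List

from pydantic import BaseModel, field_validator


class NormMapping(BaseModel):
    norm: str
    linears: List[str]

    @field_validator("linears", mode="before")
    def cast_to_list(cls, value):
        # accept a single pattern string where a list of patterns is expected
        if isinstance(value, str):
            return [value]
        return value


_llama_norm_mappings = [
    # fuse each norm into the linear layers that consume its output
    NormMapping(
        norm="re:.*input_layernorm$",
        linears=["re:.*q_proj$", "re:.*k_proj$", "re:.*v_proj$"],
    ),
    NormMapping(
        norm="re:.*post_attention_layernorm$",
        linears=["re:.*up_proj$", "re:.*gate_proj$"],
    ),
    NormMapping(
        norm="model.norm",
        linears=["lm_head"],
    ),
]

NORM_MAPPING_REGISTRY = {
    "LlamaForCausalLM": _llama_norm_mappings,
}
```

Note that with plain dict lookups, `on_initialize` raises a bare `KeyError` for any architecture that is not registered, which is presumably part of what the `# TODO: more validation` comment refers to.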