@@ -55,24 +55,34 @@ class SpinQuantModifier(Modifier, use_enum_values=True):
55
55
A mapping will be inferred if None is provided
56
56
:param norm_mappings: Specifies layers within a model to target for norm fusing.
57
57
A mapping will be inferred if None is provided
58
- :param transform_config: Optional transform config which overrides `mappings`
58
+ :param transform_config: Optional transform config for overriding provided arguments
59
59
"""
60
60
61
- rotations : List [SpinquantRotation ] = Field (default_factory = lambda : ["R1" , "R2" ])
61
+ rotations : List [SpinquantRotation ] = Field (
62
+ default_factory = lambda : ["R1" , "R2" ], exclude = True
63
+ )
62
64
transform_type : Literal ["hadamard" , "random-hadamard" , "random-matrix" ] = Field (
63
- default = "hadamard"
65
+ default = "hadamard" , exclude = True
64
66
)
65
- randomize : bool = Field (default = False )
66
- learnable : bool = Field (default = False )
67
+ randomize : bool = Field (default = False , exclude = True )
68
+ learnable : bool = Field (default = False , exclude = True )
67
69
68
70
# norm mappings separate from spinquant mappings to allow users to
69
71
# override spinquant mappings with transform_config without overriding norms
70
- mappings : Optional [SpinQuantMapping ] = Field (default = None , exclude = True )
71
- norm_mappings : Optional [List [NormMapping ]] = Field (default = None , exclude = True )
72
+ mappings : Optional [SpinQuantMapping ] = Field (
73
+ default = None ,
74
+ repr = False ,
75
+ exclude = True ,
76
+ )
77
+ norm_mappings : Optional [List [NormMapping ]] = Field (
78
+ default = None ,
79
+ repr = False ,
80
+ exclude = True ,
81
+ )
72
82
73
83
# optional override for more fine-grained control
74
84
# also included in recipe serialization
75
- transform_config : Optional [TransformConfig ] = Field (default = None )
85
+ transform_config : Optional [TransformConfig ] = Field (default = None , repr = False )
76
86
77
87
@field_validator ("randomize" , "learnable" , mode = "before" )
78
88
def validate_not_implemented (cls , value , info : ValidationInfo ):
@@ -86,12 +96,6 @@ def validate_rotations(cls, value):
86
96
87
97
def on_initialize (self , state : State , ** kwargs ) -> bool :
88
98
if self .transform_config is not None :
89
- if self .mappings is not None :
90
- raise ValueError (
91
- "Please provide either `transform_config` or `mappings` "
92
- "but not both"
93
- )
94
-
95
99
return True
96
100
97
101
self .mappings = infer_mapping_from_model (state .model )
0 commit comments