@@ -143,23 +143,45 @@ def compile_model(args, model, config, tokenizer):
143143 print (f" -- Writing config.json" )
144144 with open (os .path .join (in_dir , "config.json" ), "r" ) as f :
145145 config_dict = json .load (f )
146- if "quantization_config" in config_dict :
147- qcfg = config_dict ["quantization_config" ]
148- qcfg ["bits" ] = args ["bits" ]
149- qcfg ["head_bits" ] = args ["head_bits" ]
150- else :
151- qcfg = {
152- "quant_method" : "exl3" ,
153- "version" : __version__ ,
154- "bits" : args ["bits" ],
155- "head_bits" : args ["head_bits" ],
146+
147+ qcfg = {
148+ "quant_method" : "exl3" ,
149+ "version" : __version__ ,
150+ "bits" : args ["bits" ],
151+ "head_bits" : args ["head_bits" ],
152+ }
153+ if "cal_rows" in args :
154+ qcfg .update ({
156155 "calibration" : {
157156 "rows" : args ["cal_rows" ],
158157 "cols" : args ["cal_cols" ],
159- },
160- "out_scales" : {True : "always" , False : "never" , None : "auto" }[args ["apply_out_scales" ]],
161- "codebook" : args ["codebook" ],
162- }
158+ }
159+ })
160+ if "apply_out_scales" in args :
161+ qcfg .update ({
162+ "out_scales" : {True : "always" , False : "never" , None : "auto" }[args ["apply_out_scales" ]]
163+ })
164+ if "codebook" in args :
165+ qcfg .update ({
166+ "codebook" : args ["codebook" ]
167+ })
168+
169+ if "quantization_config" in config_dict :
170+ orig_qcfg = config_dict ["quantization_config" ].copy ()
171+ if orig_qcfg .get ("quant_method" ) == "exl3" :
172+ qcfg = orig_qcfg
173+ qcfg .update ({
174+ "bits" : args ["bits" ],
175+ "head_bits" : args ["head_bits" ],
176+ })
177+ if "codebook" in args :
178+ qcfg .update ({
179+ "codebook" : args ["codebook" ]
180+ })
181+ else :
182+ qcfg .update ({
183+ "original_quantization_config" : orig_qcfg
184+ })
163185
164186 update_config (config_dict )
165187 config_dict ["quantization_config" ] = qcfg
0 commit comments