@@ -156,19 +156,31 @@ def _verify_embedding_mode(self) -> None:
         self.embedding_mode = any(
             ModelRegistry.is_embedding_model(arch) for arch in architectures)
 
+    def _parse_quant_hf_config(self):
+        quant_cfg = getattr(self.hf_config, "quantization_config", None)
+        if quant_cfg is None:
+            # SparseML uses a "compression_config" with a "quantization_config".
+            compression_cfg = getattr(self.hf_config, "compression_config",
+                                      None)
+            if compression_cfg is not None:
+                quant_cfg = compression_cfg.get("quantization_config", None)
+
+        return quant_cfg
+
     def _verify_quantization(self) -> None:
         supported_quantization = [*QUANTIZATION_METHODS]
         rocm_supported_quantization = ["gptq", "squeezellm"]
         if self.quantization is not None:
             self.quantization = self.quantization.lower()
 
         # Parse quantization method from the HF model config, if available.
-        quant_cfg = getattr(self.hf_config, "quantization_config", None)
+        quant_cfg = self._parse_quant_hf_config()
+
         if quant_cfg is not None:
             quant_method = quant_cfg.get("quant_method", "").lower()
 
             # Detect which checkpoint is it
-            for name, method in QUANTIZATION_METHODS.items():
+            for _, method in QUANTIZATION_METHODS.items():
                 quantization_override = method.override_quantization_method(
                     quant_cfg, self.quantization)
                 if quantization_override:
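For illustration only (not part of the diff), the sketch below shows the lookup order that the new _parse_quant_hf_config helper implements: prefer a top-level "quantization_config", otherwise fall back to the "quantization_config" nested inside SparseML's "compression_config". The parse_quant_config function, the SimpleNamespace stand-ins for hf_config, and the quant_method values are all hypothetical, chosen only to make the example self-contained and runnable.

# Illustrative sketch only; mirrors the fallback logic in the diff above.
from types import SimpleNamespace


def parse_quant_config(hf_config):
    # Prefer the standard top-level "quantization_config".
    quant_cfg = getattr(hf_config, "quantization_config", None)
    if quant_cfg is None:
        # SparseML checkpoints nest it under "compression_config".
        compression_cfg = getattr(hf_config, "compression_config", None)
        if compression_cfg is not None:
            quant_cfg = compression_cfg.get("quantization_config", None)
    return quant_cfg


# Checkpoint exposing a top-level quantization_config (e.g. GPTQ-style).
direct = SimpleNamespace(quantization_config={"quant_method": "gptq"})
assert parse_quant_config(direct)["quant_method"] == "gptq"

# SparseML-style checkpoint: only compression_config is present, yet the
# helper still recovers the nested quantization settings.
nested = SimpleNamespace(
    compression_config={"quantization_config": {"quant_method": "example"}})
assert parse_quant_config(nested)["quant_method"] == "example"

# Neither attribute present: the helper returns None, so the caller's
# "if quant_cfg is not None" branch is simply skipped.
assert parse_quant_config(SimpleNamespace()) is None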