Commit d910816

[Bugfix] Automatically Detect SparseML models (#5119)
1 parent 87d41c8 commit d910816


vllm/config.py

Lines changed: 14 additions & 2 deletions
@@ -156,19 +156,31 @@ def _verify_embedding_mode(self) -> None:
         self.embedding_mode = any(
             ModelRegistry.is_embedding_model(arch) for arch in architectures)
 
+    def _parse_quant_hf_config(self):
+        quant_cfg = getattr(self.hf_config, "quantization_config", None)
+        if quant_cfg is None:
+            # SparseML uses a "compression_config" with a "quantization_config".
+            compression_cfg = getattr(self.hf_config, "compression_config",
+                                      None)
+            if compression_cfg is not None:
+                quant_cfg = compression_cfg.get("quantization_config", None)
+
+        return quant_cfg
+
     def _verify_quantization(self) -> None:
         supported_quantization = [*QUANTIZATION_METHODS]
         rocm_supported_quantization = ["gptq", "squeezellm"]
         if self.quantization is not None:
             self.quantization = self.quantization.lower()
 
         # Parse quantization method from the HF model config, if available.
-        quant_cfg = getattr(self.hf_config, "quantization_config", None)
+        quant_cfg = self._parse_quant_hf_config()
+
         if quant_cfg is not None:
             quant_method = quant_cfg.get("quant_method", "").lower()
 
             # Detect which checkpoint is it
-            for name, method in QUANTIZATION_METHODS.items():
+            for _, method in QUANTIZATION_METHODS.items():
                 quantization_override = method.override_quantization_method(
                     quant_cfg, self.quantization)
                 if quantization_override:
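
For context, the change makes vLLM fall back to the SparseML location when the usual quantization_config attribute is absent from the Hugging Face config. The snippet below is a minimal, standalone sketch of that fallback, using SimpleNamespace objects in place of real Hugging Face config objects; the attribute names mirror the diff above, while the example values (the "gptq" entry and the stand-in SparseML payload) are illustrative only.

from types import SimpleNamespace

def parse_quant_hf_config(hf_config):
    # Prefer the standard top-level "quantization_config" attribute.
    quant_cfg = getattr(hf_config, "quantization_config", None)
    if quant_cfg is None:
        # SparseML checkpoints nest it inside "compression_config".
        compression_cfg = getattr(hf_config, "compression_config", None)
        if compression_cfg is not None:
            quant_cfg = compression_cfg.get("quantization_config", None)
    return quant_cfg

# Typical quantized checkpoint: the config is a top-level attribute.
standard = SimpleNamespace(quantization_config={"quant_method": "gptq"})
# SparseML-style checkpoint: the same dict sits under "compression_config".
# (The quant_method value here is a placeholder, not a real method name.)
sparseml = SimpleNamespace(
    compression_config={"quantization_config": {"quant_method": "example"}})

print(parse_quant_hf_config(standard))   # -> {'quant_method': 'gptq'}
print(parse_quant_hf_config(sparseml))   # -> {'quant_method': 'example'}

With this fallback in place, the rest of _verify_quantization reads quant_method and applies any quantization_override exactly as before, regardless of where the checkpoint stored its quantization settings.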
