Commit 7d4e1b8

[Misc] Add support for new autogptq checkpoint_format (#3689)

Authored by Qubitium and Robert Shaw
Co-authored-by: Robert Shaw <[email protected]>

1 parent 93deb0b commit 7d4e1b8

File tree: 2 files changed, +83 −13 lines

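For context on the diffs below: the commit reconciles two generations of autogptq metadata inside a checkpoint's quantization_config. A minimal sketch of the two layouts follows; the dictionaries are illustrative (real configs carry additional fields such as bits and group_size) and only show the keys the new detection logic reads.

# Illustrative quantization_config payloads, truncated to the relevant keys.

# autogptq <= 0.7.1: Marlin serialization signalled by a boolean flag.
old_style = {
    "quant_method": "gptq",
    "is_marlin_format": True,
}

# autogptq >= 0.8.0: Marlin serialization signalled by checkpoint_format.
new_style = {
    "quant_method": "gptq",
    "checkpoint_format": "marlin",
}

# The updated vllm/config.py treats either shape as a Marlin checkpoint:
for cfg in (old_style, new_style):
    is_format_marlin = (cfg.get("checkpoint_format") == "marlin"
                        or cfg.get("is_marlin_format", False))
    assert is_format_marlin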
tests/quantization/test_autogptq_marlin_configs.py

Lines changed: 68 additions & 0 deletions
@@ -0,0 +1,68 @@
+"""Tests whether Marlin models can be loaded from the autogptq config.
+
+Run `pytest tests/quantization/test_autogptq_marlin_configs.py --forked`.
+"""
+
+from dataclasses import dataclass
+
+import pytest
+
+from vllm.config import ModelConfig
+
+
+@dataclass
+class ModelPair:
+    model_marlin: str
+    model_gptq: str
+
+
+# Model Id // Expected Kernel
+MODELS_QUANT_TYPE = [
+    # compat: autogptq <=0.7.1 is_marlin_format: bool
+    ("neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin", "marlin"),
+    ("TheBloke/Llama-2-7B-Chat-GPTQ", "gptq"),
+    # compat: autogptq >=0.8.0 use checkpoint_format: str
+    ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit", "marlin"),
+    ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", "gptq")
+]
+
+
+@pytest.mark.parametrize("model_quant_type", MODELS_QUANT_TYPE)
+def test_auto_gptq(model_quant_type: str, ) -> None:
+    model_path, quant_type = model_quant_type
+
+    model_config_no_quant_arg = ModelConfig(
+        model_path,
+        model_path,
+        tokenizer_mode="auto",
+        trust_remote_code=False,
+        download_dir=None,
+        load_format="dummy",
+        seed=0,
+        dtype="float16",
+        revision=None,
+        quantization=None  # case 1
+    )
+
+    model_config_quant_arg = ModelConfig(
+        model_path,
+        model_path,
+        tokenizer_mode="auto",
+        trust_remote_code=False,
+        download_dir=None,
+        load_format="dummy",
+        seed=0,
+        dtype="float16",
+        revision=None,
+        quantization="gptq"  # case 2
+    )
+
+    assert model_config_no_quant_arg.quantization == quant_type, (
+        f"Expected quant_type == {quant_type} for {model_path}, "
+        f"but found {model_config_no_quant_arg.quantization} "
+        "for no --quantization None case")
+
+    assert model_config_quant_arg.quantization == quant_type, (
+        f"Expected quant_type == {quant_type} for {model_path}, "
+        f"but found {model_config_quant_arg.quantization} "
+        "for --quantization gptq case")

vllm/config.py

Lines changed: 15 additions & 13 deletions
@@ -171,26 +171,28 @@ def _verify_quantization(self) -> None:
             self.quantization = self.quantization.lower()
 
         # Parse quantization method from the HF model config, if available.
-        hf_quant_config = getattr(self.hf_config, "quantization_config", None)
-        if hf_quant_config is not None:
-            hf_quant_method = str(hf_quant_config["quant_method"]).lower()
-
-            # If the GPTQ model is serialized in marlin format, use marlin.
-            if (hf_quant_method == "gptq"
-                    and "is_marlin_format" in hf_quant_config
-                    and hf_quant_config["is_marlin_format"]):
+        quant_cfg = getattr(self.hf_config, "quantization_config", None)
+        if quant_cfg is not None:
+            quant_method = quant_cfg.get("quant_method", "").lower()
+            # compat: autogptq >=0.8.0 use checkpoint_format: str
+            # compat: autogptq <=0.7.1 is_marlin_format: bool
+            is_format_marlin = (quant_cfg.get("checkpoint_format") == "marlin"
+                                or quant_cfg.get("is_marlin_format", False))
+
+            # Use marlin if the GPTQ model is serialized in marlin format.
+            if quant_method == "gptq" and is_format_marlin:
                 logger.info("The model is serialized in Marlin format. "
                             "Using Marlin kernel.")
-                hf_quant_method = "marlin"
+                quant_method = "marlin"
                 if self.quantization == "gptq":
-                    self.quantization = hf_quant_method
+                    self.quantization = quant_method
 
             if self.quantization is None:
-                self.quantization = hf_quant_method
-            elif self.quantization != hf_quant_method:
+                self.quantization = quant_method
+            elif self.quantization != quant_method:
                 raise ValueError(
                     "Quantization method specified in the model config "
-                    f"({hf_quant_method}) does not match the quantization "
+                    f"({quant_method}) does not match the quantization "
                     f"method specified in the `quantization` argument "
                     f"({self.quantization}).")
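Two details of the rewrite are worth noting: quant_cfg.get("quant_method", "") tolerates configs that omit the key (the old hf_quant_config["quant_method"] indexing would raise KeyError), and the combined is_format_marlin predicate accepts either the legacy boolean or the new string field. The mismatch guard is unchanged in spirit; a hypothetical sketch (not part of the commit) of how it fires when the caller forces a method that disagrees with the checkpoint:

import pytest

from vllm.config import ModelConfig

# Hypothetical: the HF config resolves to "gptq", but "awq" is requested,
# so _verify_quantization raises the ValueError shown in the diff above.
with pytest.raises(ValueError):
    ModelConfig("TheBloke/Llama-2-7B-Chat-GPTQ",
                "TheBloke/Llama-2-7B-Chat-GPTQ",
                tokenizer_mode="auto",
                trust_remote_code=False,
                download_dir=None,
                load_format="dummy",
                seed=0,
                dtype="float16",
                revision=None,
                quantization="awq")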

0 commit comments