
Commit 183a709

JartX and Isotr0py authored
[BUGFIX] GPTQ quantization compatibility for Qwen3 MOE models (AutoGPTQ and AutoRound-GPTQ) (#23994)
Signed-off-by: JartX <[email protected]>
Signed-off-by: Isotr0py <[email protected]>
Co-authored-by: Isotr0py <[email protected]>
1 parent 14b4326 commit 183a709

File tree: 3 files changed (+17, -4 lines)


vllm/model_executor/layers/quantization/gptq.py

Lines changed: 7 additions & 1 deletion

```diff
@@ -37,6 +37,7 @@ def __init__(
         desc_act: bool,
         lm_head_quantized: bool,
         dynamic: dict[str, dict[str, Union[int, bool]]],
+        autoround_version: str = "",
     ) -> None:
         # GPTQModel use `dynamic` config property to allow per module
         # quantization config so each module can be individually optimized.
@@ -74,6 +75,9 @@ def __init__(
                 "Currently, only 2/3/4/8-bit weight quantization is "
                 f"supported for GPTQ, but got {self.weight_bits} bits.")

+        # used to identify GPTQ model quantized by autoround
+        self.autoround_version = autoround_version
+
     def __repr__(self) -> str:
         return (f"GPTQConfig(weight_bits={self.weight_bits}, "
                 f"group_size={self.group_size}, "
@@ -108,8 +112,10 @@ def from_config(cls, config: dict[str, Any]) -> "GPTQConfig":
         desc_act = cls.get_from_keys(config, ["desc_act"])
         lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
                                                  default=False)
+        autoround_version = cls.get_from_keys_or(config, ["autoround_version"],
+                                                 default="")
        return cls(weight_bits, group_size, desc_act, lm_head_quantized,
-                   dynamic)
+                   dynamic, autoround_version)

     def get_quant_method(
             self, layer: torch.nn.Module, prefix: str
```
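Taken together, the two hunks thread a new, optional `autoround_version` field from the checkpoint's quantization config into `GPTQConfig`. A minimal sketch of that flow, not part of the commit: the config dicts below are hypothetical, and it assumes `from_config` reads `dynamic` with an empty-dict default elsewhere in the method.

```python
from vllm.model_executor.layers.quantization.gptq import GPTQConfig

# Hypothetical AutoRound-GPTQ quantization config: it carries an
# "autoround_version" key alongside the usual GPTQ fields.
autoround_cfg = {
    "bits": 4,
    "group_size": 128,
    "desc_act": False,
    "autoround_version": "0.4.2",  # hypothetical version string
}
assert GPTQConfig.from_config(autoround_cfg).autoround_version == "0.4.2"

# Hypothetical AutoGPTQ config: no such key, so the "" default applies
# and the attribute stays falsy.
autogptq_cfg = {"bits": 4, "group_size": 128, "desc_act": False}
assert GPTQConfig.from_config(autogptq_cfg).autoround_version == ""
```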

vllm/model_executor/layers/quantization/gptq_marlin.py

Lines changed: 3 additions & 0 deletions

```diff
@@ -119,6 +119,9 @@ def __init__(self, weight_bits: int, group_size: int, desc_act: bool,

         self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)]

+        # used to identify GPTQ model quantized by autoround
+        self.autoround_version = full_config.get("autoround_version", "")
+
     def __repr__(self) -> str:
         return (f"GPTQMarlinConfig(quant_type={self.quant_type}, "
                 f"group_size={self.group_size}, "
```

vllm/model_executor/models/qwen3_moe.py

Lines changed: 7 additions & 3 deletions

```diff
@@ -159,9 +159,13 @@ def __init__(

     def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
         # GPTQ configs do not have a list of ignored modules, however AutoGPTQ
-        # seems to avoid gate quantization.
-        # See: https://huggingface.co/Qwen/Qwen3-30B-A3B-GPTQ-Int4
-        if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)):
+        # seems to avoid gate quantization while AutoRound does.
+        # See: https://huggingface.co/Qwen/Qwen3-30B-A3B-GPTQ-Int4,
+        # and https://huggingface.co/jart25/Qwen3-Coder-30B-A3B-Instruct-Int4-gptq
+        if isinstance(
+                quant_config,
+                (GPTQConfig,
+                 GPTQMarlinConfig)) and not quant_config.autoround_version:
             return None
         return quant_config
```
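The behavioural change is easiest to see in isolation: for AutoGPTQ checkpoints the MoE gate's quant config is dropped (returning None leaves the gate unquantized), while a set `autoround_version` keeps it, since AutoRound quantizes the gate too. A self-contained sketch with a stand-in config class (not vLLM's):

```python
# Stand-in for GPTQConfig/GPTQMarlinConfig; only the flag matters here.
class FakeGPTQConfig:
    def __init__(self, autoround_version: str = "") -> None:
        self.autoround_version = autoround_version


def maybe_ignore_quant_config(quant_config):
    # Mirrors the gating logic above, with the isinstance check reduced
    # to the stand-in class.
    if (isinstance(quant_config, FakeGPTQConfig)
            and not quant_config.autoround_version):
        return None  # AutoGPTQ: leave the gate unquantized
    return quant_config  # AutoRound: keep the gate quantized


assert maybe_ignore_quant_config(FakeGPTQConfig()) is None
assert maybe_ignore_quant_config(FakeGPTQConfig("0.4.2")) is not None
```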
