
Commit 3654847

feat: Add support for GPTQ-quantized MoE models on ROCm for vllm serve (#21733)
1 parent: eefbf4a

2 files changed: +21 additions, -5 deletions
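
With this change, GPTQ-quantized mixture-of-experts checkpoints can be served on ROCm: MoE layers are routed through the MoeWNA16 method (see gptq.py below) and the CUDA-only WNA16 kernel is skipped on non-CUDA platforms (see fused_moe.py below). Below is a hypothetical usage sketch with the offline LLM API, which shares the quantization path with vllm serve; the model id is a placeholder, not a checkpoint validated in this commit.

# Offline-inference sketch; on ROCm the MoE experts now go through the
# Triton-backed MoeWNA16 kernels instead of a CUDA-only path.
from vllm import LLM, SamplingParams

llm = LLM(model="some-org/some-moe-model-GPTQ-Int4",  # placeholder repo id
          quantization="gptq")
outputs = llm.generate(["Hello, ROCm!"],
                       SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)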

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 2 additions & 2 deletions
@@ -761,8 +761,8 @@ def get_moe_wna16_block_config(config: dict[str,
 
 def should_moe_wna16_use_cuda(num_valid_tokens: int, group_size: int,
                               num_experts: int, bit: int):
-    return bit == 4 and group_size in [32, 64, 128] and \
-        num_valid_tokens / num_experts <= 6
+    return current_platform.is_cuda() and bit == 4 and \
+        group_size in [32, 64, 128] and num_valid_tokens / num_experts <= 6
 
 
 def get_default_config(
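
This hunk gates the dedicated WNA16 CUDA kernel behind current_platform.is_cuda(), so non-CUDA platforms such as ROCm take the Triton fused-MoE path instead. A stand-alone sketch of the new predicate follows, with a stubbed platform object passed in explicitly to keep the example self-contained (vLLM's real current_platform is a module-level import from vllm.platforms):

# Stub platform object; stands in for vllm.platforms.current_platform.
class _StubPlatform:
    def __init__(self, cuda: bool) -> None:
        self._cuda = cuda

    def is_cuda(self) -> bool:
        return self._cuda


def should_moe_wna16_use_cuda(current_platform, num_valid_tokens: int,
                              group_size: int, num_experts: int, bit: int):
    # The dedicated CUDA kernel is only worthwhile for 4-bit weights,
    # supported group sizes, and few valid tokens per expert; it is now
    # also skipped entirely on non-CUDA platforms such as ROCm.
    return current_platform.is_cuda() and bit == 4 and \
        group_size in [32, 64, 128] and num_valid_tokens / num_experts <= 6


# 64 valid tokens over 16 experts, 4-bit weights, group size 128:
print(should_moe_wna16_use_cuda(_StubPlatform(cuda=True), 64, 128, 16, 4))   # True
print(should_moe_wna16_use_cuda(_StubPlatform(cuda=False), 64, 128, 16, 4))  # False -> Triton path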

vllm/model_executor/layers/quantization/gptq.py

Lines changed: 19 additions & 3 deletions
@@ -10,10 +10,11 @@
 from torch.nn.parameter import Parameter
 
 from vllm import _custom_ops as ops
+from vllm.model_executor.layers.fused_moe.layer import FusedMoE
 from vllm.model_executor.layers.linear import LinearMethodBase
 from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+    QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.utils.gptq_utils import (
     get_linear_quant_method)
 from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
@@ -110,8 +111,23 @@ def from_config(cls, config: dict[str, Any]) -> "GPTQConfig":
         return cls(weight_bits, group_size, desc_act, lm_head_quantized,
                    dynamic)
 
-    def get_quant_method(self, layer: torch.nn.Module,
-                         prefix: str) -> Optional["GPTQLinearMethod"]:
+    def get_quant_method(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional[Union["GPTQLinearMethod", "QuantizeMethodBase"]]:
+        if isinstance(layer, FusedMoE):
+            # GPTQ MoE support: fall back to MoeWNA16 for broad compatibility
+            from .moe_wna16 import MoeWNA16Config
+
+            config = {
+                "quant_method": "gptq",
+                "bits": self.weight_bits,
+                "group_size": self.group_size,
+                "sym": True,  # GPTQ typically uses symmetric quantization
+                "lm_head": False,
+            }
+            return MoeWNA16Config.from_config(config).get_quant_method(
+                layer, prefix)
+
         return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod)
 
 
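
This is where ROCm support actually comes from: when get_quant_method is asked about a FusedMoE layer, the GPTQ config delegates to MoeWNA16Config instead of returning a linear-only method. The sketch below is a simplified, self-contained illustration of that delegation pattern; the stub classes only mimic the real FusedMoE / MoeWNA16Config / GPTQLinearMethod types and are not vLLM's API.

# Stub types standing in for vLLM's real classes, to keep the sketch runnable.
class FusedMoE:        # vllm.model_executor.layers.fused_moe.layer.FusedMoE
    pass

class LinearLayer:     # any non-MoE layer
    pass

class MoeWNA16Method:  # what MoeWNA16Config.get_quant_method would return
    def __init__(self, cfg: dict) -> None:
        self.cfg = cfg

class GPTQLinearMethod:
    def __init__(self, cfg) -> None:
        self.cfg = cfg

class MoeWNA16Config:
    def __init__(self, cfg: dict) -> None:
        self.cfg = cfg

    @classmethod
    def from_config(cls, cfg: dict) -> "MoeWNA16Config":
        return cls(cfg)

    def get_quant_method(self, layer, prefix: str) -> MoeWNA16Method:
        return MoeWNA16Method(self.cfg)

class GPTQConfig:
    def __init__(self, weight_bits: int = 4, group_size: int = 128) -> None:
        self.weight_bits = weight_bits
        self.group_size = group_size

    def get_quant_method(self, layer, prefix: str):
        if isinstance(layer, FusedMoE):
            # Re-pack the GPTQ settings as a MoeWNA16 config so the MoE
            # experts use the weight-only WNA16 kernels (Triton on ROCm).
            config = {
                "quant_method": "gptq",
                "bits": self.weight_bits,
                "group_size": self.group_size,
                "sym": True,
                "lm_head": False,
            }
            return MoeWNA16Config.from_config(config).get_quant_method(
                layer, prefix)
        return GPTQLinearMethod(self)


cfg = GPTQConfig()
print(type(cfg.get_quant_method(FusedMoE(), "mlp.experts")).__name__)        # MoeWNA16Method
print(type(cfg.get_quant_method(LinearLayer(), "self_attn.qkv")).__name__)   # GPTQLinearMethod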
