|
10 | 10 | from torch.nn.parameter import Parameter
|
11 | 11 |
|
12 | 12 | from vllm import _custom_ops as ops
|
| 13 | +from vllm.model_executor.layers.fused_moe.layer import FusedMoE |
13 | 14 | from vllm.model_executor.layers.linear import LinearMethodBase
|
14 | 15 | from vllm.model_executor.layers.quantization import QuantizationMethods
|
15 | 16 | from vllm.model_executor.layers.quantization.base_config import (
|
16 |
| - QuantizationConfig) |
| 17 | + QuantizationConfig, QuantizeMethodBase) |
17 | 18 | from vllm.model_executor.layers.quantization.utils.gptq_utils import (
|
18 | 19 | get_linear_quant_method)
|
19 | 20 | from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
|
@@ -110,8 +111,23 @@ def from_config(cls, config: dict[str, Any]) -> "GPTQConfig":
|
110 | 111 | return cls(weight_bits, group_size, desc_act, lm_head_quantized,
|
111 | 112 | dynamic)
|
112 | 113 |
|
113 |
def get_quant_method(
    self, layer: torch.nn.Module, prefix: str
) -> Optional[Union["GPTQLinearMethod", "QuantizeMethodBase"]]:
    """Select the quantization method object for ``layer``.

    ``FusedMoE`` layers are routed through ``MoeWNA16Config``: a config
    dict is built from this object's ``weight_bits`` and ``group_size``
    and the resulting config's ``get_quant_method`` result is returned.
    Every other layer is delegated to ``get_linear_quant_method`` with
    ``GPTQLinearMethod`` as the method class.

    Args:
        layer: The module a quantization method is being chosen for.
        prefix: Name prefix forwarded unchanged to the delegate that
            performs the selection.

    Returns:
        Whatever the chosen delegate returns; the annotation allows
        ``None``.
    """
    if isinstance(layer, FusedMoE):
        # GPTQ MoE support: fall back to MoeWNA16 for broad compatibility.
        # Imported here rather than at module level.
        # NOTE(review): presumably to avoid an import cycle -- confirm.
        from .moe_wna16 import MoeWNA16Config

        config = {
            "quant_method": "gptq",
            "bits": self.weight_bits,
            "group_size": self.group_size,
            # Hard-coded rather than read from this config.
            # NOTE(review): GPTQ typically uses symmetric quantization;
            # confirm asymmetric checkpoints never reach this path.
            "sym": True,
            "lm_head": False,
        }
        return MoeWNA16Config.from_config(config).get_quant_method(
            layer, prefix)

    return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod)
|
116 | 132 |
|
117 | 133 |
|
|
0 commit comments