
Commit 1ece7f3

Authored by Jun-Howie (JunHowie), with gemini-code-assist[bot]
Fix: AWQ Marlin get_quant_method does not recognize "modules_to_not_convert" (#21888)
Signed-off-by: JunHowie <[email protected]>
Co-authored-by: JunHowie <[email protected]>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent bc8372e commit 1ece7f3
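
For context: `modules_to_not_convert` is the list of layer-name fragments that an AWQ checkpoint leaves unquantized, and this commit makes the MoE branch of `get_quant_method` honor it. A minimal, illustrative sketch of where the field typically appears in a checkpoint's quantization config follows; the specific values and module names are hypothetical, not taken from any real model:

# Illustrative only: rough shape of an AWQ quantization_config carrying
# modules_to_not_convert. Field values and module names are hypothetical.
quantization_config = {
    "quant_method": "awq",
    "bits": 4,
    "group_size": 128,
    "zero_point": True,
    "modules_to_not_convert": ["mlp.shared_expert", "mlp.gate"],
}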

File tree

1 file changed: +6 -2 lines changed

vllm/model_executor/layers/quantization/awq_marlin.py

Lines changed: 6 additions & 2 deletions
@@ -10,7 +10,8 @@
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.layer import (
-    FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
+    FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported,
+    UnquantizedFusedMoEMethod)
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                UnquantizedLinearMethod,
                                                set_weight_attrs)
@@ -141,6 +142,9 @@ def get_quant_method(self, layer: torch.nn.Module,
         elif isinstance(layer, FusedMoE):
             from vllm.model_executor.layers.quantization.moe_wna16 import (
                 MoeWNA16Config)
+            if is_layer_skipped_awq(
+                    prefix, getattr(self, "modules_to_not_convert", [])):
+                return UnquantizedFusedMoEMethod(layer.moe_config)
             if not check_moe_marlin_supports_layer(layer, self.group_size):
                 logger.warning_once(
                     f"Layer '{prefix}' is not supported by AWQMoeMarlin. "
@@ -520,4 +524,4 @@ def apply(
             expert_map=expert_map,
             w1_zeros=layer.w13_qzeros,
             w2_zeros=layer.w2_qzeros,
-            workspace=layer.workspace)
+            workspace=layer.workspace)
