Commit e31b31f

[main][Bugfix] Fix unable to load qwen3_moe quantized weights (#2219)
### What this PR does / why we need it?
Fixes the issue where `qwen3_moe` quantized weights could not be loaded, introduced by #1994.

### Does this PR introduce _any_ user-facing change?
None.

### How was this patch tested?
Added a `qwen3_moe` W8A8 quantized-model case to `tests/e2e/multicard/test_qwen3_moe.py`.

- vLLM version: v0.10.0
- vLLM main: vllm-project/vllm@c494f96

---------

Signed-off-by: zhoux77899 <[email protected]>
Parent commit: 54ace9e

2 files changed: +53 additions, −5 deletions

- tests/e2e/multicard/test_qwen3_moe.py
- vllm_ascend/models/qwen3_moe.py

tests/e2e/multicard/test_qwen3_moe.py

Lines changed: 20 additions & 1 deletion
@@ -18,9 +18,11 @@
 #
 """Compare the short outputs of HF and vLLM when using greedy sampling.
 
-Run `pytest tests/test_offline_inference.py`.
+Run `pytest tests/e2e/multicard/test_qwen3_moe.py`.
 """
 
+from modelscope import snapshot_download  # type: ignore
+
 from tests.e2e.conftest import VllmRunner
 
 
@@ -53,3 +55,20 @@ def test_models_distributed_Qwen3_MOE_TP2_WITH_EP():
             distributed_executor_backend="mp",
     ) as vllm_model:
         vllm_model.generate_greedy(example_prompts, max_tokens)
+
+
+def test_models_distributed_Qwen3_MOE_W8A8():
+    example_prompts = [
+        "Hello, my name is",
+    ]
+    dtype = "auto"
+    max_tokens = 5
+    with VllmRunner(
+            snapshot_download("vllm-ascend/Qwen3-30B-A3B-W8A8"),
+            max_model_len=8192,
+            dtype=dtype,
+            tensor_parallel_size=2,
+            quantization="ascend",
+            enforce_eager=False,
+    ) as vllm_model:
+        vllm_model.generate_greedy(example_prompts, max_tokens)
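
For reference, a minimal standalone sketch of loading the same W8A8 checkpoint through vLLM's offline `LLM` API rather than the `VllmRunner` test fixture. The model ID and engine arguments mirror the test above; `quantization="ascend"` assumes the vllm-ascend plugin is installed on an Ascend NPU environment, so treat this as an illustration, not part of the commit.

```python
# Minimal sketch: offline inference with the W8A8 checkpoint used by the new
# e2e test. Assumes an Ascend NPU setup with vllm-ascend installed so that
# quantization="ascend" is recognized by the engine.
from modelscope import snapshot_download
from vllm import LLM, SamplingParams

model_path = snapshot_download("vllm-ascend/Qwen3-30B-A3B-W8A8")

llm = LLM(
    model=model_path,
    max_model_len=8192,
    tensor_parallel_size=2,
    quantization="ascend",
    enforce_eager=False,
)

# Greedy decoding, matching generate_greedy(max_tokens=5) in the test.
params = SamplingParams(temperature=0.0, max_tokens=5)
outputs = llm.generate(["Hello, my name is"], params)
print(outputs[0].outputs[0].text)
```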

vllm_ascend/models/qwen3_moe.py

Lines changed: 33 additions & 4 deletions
@@ -1,6 +1,7 @@
 # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# Copyright 2024 The Qwen team.
 # Copyright 2023 The vLLM team.
-#
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -26,20 +27,23 @@
 from vllm.distributed.parallel_state import (get_dp_group, get_ep_group,
                                              get_tp_group)
 from vllm.forward_context import get_forward_context
+from vllm.model_executor.layers.fused_moe.layer import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import ReplicatedLinear
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
+from vllm.model_executor.models.interfaces import (MixtureOfExperts,
+                                                    SupportsLoRA, SupportsPP)
 from vllm.model_executor.models.qwen3_moe import (Qwen3MoeAttention,
                                                   Qwen3MoeDecoderLayer,
                                                   Qwen3MoeForCausalLM,
                                                   Qwen3MoeMLP, Qwen3MoeModel,
                                                   Qwen3MoeSparseMoeBlock)
 from vllm.model_executor.models.utils import (
-    extract_layer_index, make_empty_intermediate_tensors_factory, make_layers,
-    maybe_prefix)
+    PPMissingLayer, extract_layer_index,
+    make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
 
 from vllm_ascend.ops.fused_moe import AscendFusedMoE
 
@@ -230,6 +234,9 @@ class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         nn.Module.__init__(self)
+        SupportsPP.__init__(self)
+        SupportsLoRA.__init__(self)
+        MixtureOfExperts.__init__(self)
         config = vllm_config.model_config.hf_config
         quant_config = vllm_config.quant_config
         self.config = config
@@ -238,9 +245,31 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                                            prefix=maybe_prefix(prefix, "model"))
         self.lm_head = ParallelLMHead(config.vocab_size,
                                       config.hidden_size,
-                                      quant_config=quant_config)
+                                      quant_config=quant_config,
+                                      prefix=maybe_prefix(prefix, "lm_head"))
         if self.config.tie_word_embeddings:
             self.lm_head.weight = self.model.embed_tokens.weight
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
+
+        # Set MoE hyperparameters
+        self.expert_weights: list[torch.Tensor] = []
+
+        self.moe_layers: list[FusedMoE] = []
+        example_layer = None
+        for layer in self.model.layers:
+            if isinstance(layer, PPMissingLayer):
+                continue
+
+            assert isinstance(layer, Qwen3MoeDecoderLayer)
+            if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock):
+                example_layer = layer.mlp
+                self.moe_layers.append(layer.mlp.experts)
+
+        if example_layer is None:
+            raise RuntimeError("No Qwen3MoE layer found in the model.layers.")
+
+        self.num_moe_layers = len(self.moe_layers)
+        self.num_expert_groups = 1
+        self.num_shared_experts = 0
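
A side note on the explicit mixin initializers added in `__init__`: because the class calls `nn.Module.__init__(self)` directly instead of `super().__init__()`, the cooperative MRO chain is bypassed, so any interface mixin that sets up per-instance state has to be initialized by hand. A minimal sketch of that general pattern, using hypothetical classes rather than the actual vLLM interfaces:

```python
# Illustration (hypothetical classes, not the vLLM interfaces): when a
# subclass bypasses super().__init__() and calls one base initializer
# directly, stateful mixins must be initialized explicitly, otherwise their
# attributes are missing at runtime.
class TrackedMixin:
    def __init__(self) -> None:
        self.tracked_modules: list[str] = []


class Base:
    def __init__(self) -> None:
        self.name = "base"


class Model(Base, TrackedMixin):
    def __init__(self) -> None:
        Base.__init__(self)          # bypasses the cooperative MRO chain,
        TrackedMixin.__init__(self)  # so the mixin is set up by hand

    def register(self, module: str) -> None:
        self.tracked_modules.append(module)


m = Model()
m.register("experts")
print(m.tracked_modules)  # ['experts']
```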
