
Commit 83092d9

[BugFix] Fix Qwen3-Next because of vllm #24982 (#3221)
- Fixes Qwen3-Next, which was broken by the constructor interface changes in vllm #24982 (a sketch of the new constructor shape precedes the diff below).

### Does this PR introduce _any_ user-facing change?

N/A

### How was this patch tested?

```python
from vllm import LLM, SamplingParams


def main():
    prompts = [
        "窗前明月光,",  # "Bright moonlight before my window," (Li Bai)
        "The president of the United States is Mr.",
        "The capital of France is",
        "The future of AI is",
        "感时花溅泪,",  # "Moved by the times, flowers shed tears," (Du Fu)
        "家书抵万金啥意思?",  # "What does 'a letter from home is worth ten thousand in gold' mean?"
        "plz tell me a story: ",
    ]
    # Create a sampling params object.
    sampling_params = SamplingParams(max_tokens=100,
                                     temperature=0.6,
                                     top_k=40,
                                     top_p=0.95)
    # Create an LLM.
    llm = LLM(
        model="Qwen/Qwen3-Next-80B-A3B-Instruct",
        tensor_parallel_size=4,
        enforce_eager=True,
        trust_remote_code=True,
        max_model_len=256,
        gpu_memory_utilization=0.7,
        block_size=64,
    )
    # Generate texts from the prompts.
    outputs = llm.generate(prompts, sampling_params)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


if __name__ == "__main__":
    main()
```

- vLLM version: v0.10.2
- vLLM main: vllm-project/vllm@releases/v0.11.0

---------

Signed-off-by: Icey <[email protected]>
1 parent c73dd8f commit 83092d9

File tree

1 file changed: +15 −33 lines changed

vllm_ascend/models/qwen3_next.py

Lines changed: 15 additions & 33 deletions
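
At a glance, the change follows the constructor refactor from vllm #24982: instead of receiving `model_config`, `cache_config`, `quant_config`, `speculative_config`, and `enable_eplb` as separate arguments, the decoder layer now takes a single `VllmConfig` and unpacks the sub-configs itself. The sketch below mirrors that shape with hypothetical stub dataclasses (the real classes live in `vllm.config`); it is an illustration of the pattern, not vLLM code:

```python
from dataclasses import dataclass, field


# Hypothetical stand-ins for vLLM's config objects; only the attributes
# the pattern touches are modeled, so the sketch runs on its own.
@dataclass
class HfConfigStub:
    hidden_size: int = 2048


@dataclass
class ModelConfigStub:
    hf_config: HfConfigStub = field(default_factory=HfConfigStub)


@dataclass
class VllmConfigStub:
    model_config: ModelConfigStub = field(default_factory=ModelConfigStub)
    cache_config: object = None
    quant_config: object = None
    speculative_config: object = None


class DecoderLayerSketch:

    def __init__(self, vllm_config: VllmConfigStub, layer_type: str,
                 prefix: str = "") -> None:
        # One argument in; sub-configs derived locally, the same shape the
        # patched CustomQwen3NextDecoderLayer uses in the diff below.
        config = vllm_config.model_config.hf_config
        self.cache_config = vllm_config.cache_config
        self.quant_config = vllm_config.quant_config
        self.speculative_config = vllm_config.speculative_config
        self.layer_type = layer_type
        self.prefix = prefix
        self.hidden_size = config.hidden_size


layer = DecoderLayerSketch(VllmConfigStub(), layer_type="full_attention",
                           prefix="model.layers.0")
print(layer.hidden_size)  # -> 2048
```

Call sites such as `get_layer` in the diff then pass only `vllm_config`, `layer_type`, and `prefix`, so future additions to `VllmConfig` no longer ripple through every constructor signature in vllm-ascend.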
```diff
@@ -51,13 +51,10 @@
 from vllm.transformers_utils.configs import Qwen3NextConfig
 from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
 
-from vllm.model_executor.models.qwen3_next import Qwen3NextAttention  # isort: skip
-from vllm.model_executor.models.qwen3_next import Qwen3NextDecoderLayer  # isort: skip
-from vllm.model_executor.models.qwen3_next import Qwen3NextForCausalLM  # isort: skip
-from vllm.model_executor.models.qwen3_next import Qwen3NextGatedDeltaNet  # isort: skip
-from vllm.model_executor.models.qwen3_next import Qwen3NextModel  # isort: skip
-from vllm.model_executor.models.qwen3_next import Qwen3NextSparseMoeBlock  # isort: skip
-from vllm.model_executor.models.qwen3_next import fused_gdn_gating  # isort: skip
+from vllm.model_executor.models.qwen3_next import (  # isort: skip
+    Qwen3NextAttention, Qwen3NextDecoderLayer, Qwen3NextForCausalLM,
+    Qwen3NextGatedDeltaNet, Qwen3NextModel, Qwen3NextSparseMoeBlock,
+    fused_gdn_gating)
 
 
 class CustomQwen3NextGatedDeltaNet(Qwen3NextGatedDeltaNet, MambaBase):
@@ -429,17 +426,16 @@ class CustomQwen3NextDecoderLayer(Qwen3NextDecoderLayer):
 
     def __init__(
         self,
-        config: Qwen3NextConfig,
+        vllm_config: VllmConfig,
         layer_type: str,
-        model_config: Optional[ModelConfig] = None,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        speculative_config: Optional[SpeculativeConfig] = None,
         prefix: str = "",
-        enable_eplb: bool = False,
     ) -> None:
         nn.Module.__init__(self)
-        self.config = config
+        config = vllm_config.model_config.hf_config
+        model_config = vllm_config.model_config
+        cache_config = vllm_config.cache_config
+        quant_config = vllm_config.quant_config
+        speculative_config = vllm_config.speculative_config
 
         self.layer_type = layer_type
         self.layer_idx = extract_layer_index(prefix)
@@ -468,12 +464,8 @@ def __init__(
         if (self.layer_idx not in mlp_only_layers) and (
                 config.num_experts > 0 and
                 (self.layer_idx + 1) % config.decoder_sparse_step == 0):
-            self.mlp = Qwen3NextSparseMoeBlock(
-                config=config,
-                quant_config=quant_config,
-                prefix=f"{prefix}.mlp",
-                enable_eplb=enable_eplb,
-            )
+            self.mlp = Qwen3NextSparseMoeBlock(vllm_config=vllm_config,
+                                               prefix=f"{prefix}.mlp")
         else:
             self.mlp = Qwen3NextMLP(
                 hidden_size=config.hidden_size,
@@ -493,14 +485,14 @@ def __init__(
             torch.zeros(
                 1,
                 1,
-                self.config.hidden_size,
+                config.hidden_size,
                 dtype=config.torch_dtype,
             ), )
         self.ffn_layer_scale = torch.nn.Parameter(
             torch.zeros(
                 1,
                 1,
-                self.config.hidden_size,
+                config.hidden_size,
                 dtype=config.torch_dtype,
             ), )
 
@@ -511,13 +503,8 @@ class CustomQwen3NextModel(Qwen3NextModel):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         nn.Module.__init__(self)
         config: Qwen3NextConfig = vllm_config.model_config.hf_config
-        model_config = vllm_config.model_config
-        cache_config = vllm_config.cache_config
-        quant_config = vllm_config.quant_config
         parallel_config = vllm_config.parallel_config
         lora_config = vllm_config.lora_config
-        speculative_config = vllm_config.speculative_config
-        enable_eplb = parallel_config.enable_eplb
         eplb_config = parallel_config.eplb_config
         self.num_redundant_experts = eplb_config.num_redundant_experts
 
@@ -534,14 +521,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
 
         def get_layer(prefix: str):
             return CustomQwen3NextDecoderLayer(
-                config,
+                vllm_config,
                 layer_type=config.layer_types[extract_layer_index(prefix)],
-                model_config=model_config,
-                cache_config=cache_config,
-                quant_config=quant_config,
-                speculative_config=speculative_config,
                 prefix=prefix,
-                enable_eplb=enable_eplb,
             )
 
         self.start_layer, self.end_layer, self.layers = make_layers(
```

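One detail worth noting at the end of the diff: `get_layer` is a factory that vLLM's `make_layers` helper calls once per layer with a dotted prefix (e.g. `model.layers.3`), and `extract_layer_index(prefix)` recovers the layer index from that prefix to pick the layer type. Below is a minimal self-contained sketch of that pattern with generic stand-ins, not the actual vLLM helpers (the real `make_layers` also handles pipeline-parallel layer placement), and with illustrative `layer_types` values:

```python
from typing import Callable, List


def extract_layer_index_sketch(prefix: str) -> int:
    # Mirrors the idea behind vLLM's extract_layer_index: the layer index is
    # the integer component of a dotted prefix such as "model.layers.3".
    return next(int(part) for part in prefix.split(".") if part.isdigit())


def make_layers_sketch(num_layers: int, layer_fn: Callable[[str], object],
                       prefix: str) -> List[object]:
    # Invoke the factory once per layer with its dotted prefix; the real
    # helper additionally assigns layer ranges to pipeline-parallel ranks.
    return [layer_fn(f"{prefix}.{i}") for i in range(num_layers)]


# Usage: pick each layer's type by the index recovered from the prefix, as
# the diff's get_layer does with config.layer_types[extract_layer_index(prefix)].
layer_types = ["linear_attention", "linear_attention", "full_attention"]
layers = make_layers_sketch(
    len(layer_types),
    lambda p: (p, layer_types[extract_layer_index_sketch(p)]),
    prefix="model.layers",
)
print(layers)
# [('model.layers.0', 'linear_attention'), ('model.layers.1', 'linear_attention'),
#  ('model.layers.2', 'full_attention')]
```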