Skip to content

Commit 66160a9

Browse files
njhill and simon-mo
authored and committed
[BugFix] Fix Qwen3-Next PP (#24709)
Signed-off-by: Nick Hill <[email protected]>
1 parent eaca762 commit 66160a9

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

vllm/model_executor/models/qwen3_next.py

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33
"""Inference-only Qwen3Next model."""
44
from collections.abc import Iterable
5+
from itertools import islice
56
from typing import Optional
67

78
import torch
@@ -917,8 +918,11 @@ def get_layer(prefix: str):
917918
make_empty_intermediate_tensors_factory(
918919
["hidden_states", "residual"], config.hidden_size))
919920

920-
self.norm = Qwen3NextRMSNorm(config.hidden_size,
921-
eps=config.rms_norm_eps)
921+
if get_pp_group().is_last_rank:
922+
self.norm = Qwen3NextRMSNorm(config.hidden_size,
923+
eps=config.rms_norm_eps)
924+
else:
925+
self.norm = PPMissingLayer()
922926

923927
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
924928
return self.embed_tokens(input_ids)
@@ -941,7 +945,7 @@ def forward(
941945
hidden_states = intermediate_tensors["hidden_states"]
942946
residual = intermediate_tensors["residual"]
943947

944-
for layer in self.layers:
948+
for layer in islice(self.layers, self.start_layer, self.end_layer):
945949
hidden_states, residual = layer(
946950
positions=positions,
947951
hidden_states=hidden_states,

0 commit comments

Comments
 (0)