Commit da3fa78

yewentao256 authored and simon-mo committed
[Compilation Bug] Fix Inductor Graph Output with Shape Issue (#24772)
Signed-off-by: yewentao256 <[email protected]>
1 parent: bbb7003

File tree: 1 file changed (+6 −3 lines)


vllm/model_executor/models/qwen3_moe.py

Lines changed: 6 additions & 3 deletions
@@ -170,8 +170,9 @@ def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
         return quant_config
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        # NOTE: hidden_states can have either 1D or 2D shape.
-        orig_shape = hidden_states.shape
+        assert hidden_states.dim(
+        ) <= 2, "Qwen3MoeSparseMoeBlock only supports 1D or 2D inputs"
+        is_input_1d = hidden_states.dim() == 1
         hidden_dim = hidden_states.shape[-1]
         hidden_states = hidden_states.view(-1, hidden_dim)
 
@@ -180,7 +181,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         final_hidden_states = self.experts(hidden_states=hidden_states,
                                            router_logits=router_logits)
 
-        return final_hidden_states.view(orig_shape)
+        # return to 1d if input is 1d
+        return final_hidden_states.squeeze(0) if is_input_1d else \
+            final_hidden_states
 
 
 class Qwen3MoeAttention(nn.Module):
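
For reference, below is a minimal, self-contained sketch of the shape-handling pattern this commit introduces. It is illustrative only: DummyExperts is a hypothetical stand-in for the real self.experts / router path, and moe_forward mirrors the patched Qwen3MoeSparseMoeBlock.forward logic rather than reproducing vLLM's actual module.

import torch
import torch.nn as nn

class DummyExperts(nn.Module):
    # Hypothetical placeholder for the fused experts + router of the real block.
    def __init__(self, hidden_dim: int):
        super().__init__()
        self.proj = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # The expert path always operates on a 2D (num_tokens, hidden_dim) tensor.
        return self.proj(hidden_states)

def moe_forward(experts: nn.Module, hidden_states: torch.Tensor) -> torch.Tensor:
    # Same pattern as the patched forward: reject >2D inputs up front,
    # remember whether the input was 1D, then flatten to 2D for the experts.
    assert hidden_states.dim() <= 2, "only 1D or 2D inputs are supported"
    is_input_1d = hidden_states.dim() == 1
    hidden_dim = hidden_states.shape[-1]
    hidden_states = hidden_states.view(-1, hidden_dim)
    out = experts(hidden_states)
    # Restore 1D via squeeze(0) driven by a plain Python bool, instead of
    # calling view(orig_shape) with a shape captured from the input tensor.
    return out.squeeze(0) if is_input_1d else out

experts = DummyExperts(hidden_dim=8)
print(moe_forward(experts, torch.randn(8)).shape)     # torch.Size([8])
print(moe_forward(experts, torch.randn(4, 8)).shape)  # torch.Size([4, 8])

The design point, as the commit title suggests, is that deriving the output shape from a Python-level bool (is_input_1d) avoids reshaping the graph output with a captured orig_shape, which is what was tripping up Inductor.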
