Skip to content

Commit 3c68bf1

Browse files
committed
Fix qwen3 vl moe execution failure
Signed-off-by: Seunghyuk Park <separk@habana.ai>
1 parent 1e012ec commit 3c68bf1

File tree

2 files changed

+34
-7
lines changed

2 files changed

+34
-7
lines changed

vllm_gaudi/models/qwen3_moe.py

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,23 +8,48 @@
88

99

1010
class HpuQwen3MoeSparseMoeBlock(UpstreamQwen3MoeSparseMoeBlock):
11+
"""
12+
Override forward to handle 3D tensor input (B,S,H) -> (B*S,H)
13+
and SharedFusedMoE tuple returns.
14+
"""
1115

1216
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
1317
orig_shape = hidden_states.shape
1418
hidden_dim = orig_shape[-1]
1519

16-
hs = hidden_states.reshape(-1, hidden_dim) # (T, H)
20+
hs = hidden_states.reshape(-1, hidden_dim) # (B*S, H)
1721
num_tokens = hs.shape[0]
1822

19-
if getattr(self, "is_sequence_parallel", False):
23+
is_seq_parallel = getattr(self, "is_sequence_parallel", False)
24+
25+
if is_seq_parallel:
2026
hs = sequence_parallel_chunk(hs)
2127

2228
router_logits, _ = self.gate(hs)
23-
out = self.experts(hidden_states=hs, router_logits=router_logits)
2429

25-
if getattr(self, "is_sequence_parallel", False):
30+
# SharedFusedMoE returns (shared_out, fused_out)
31+
experts_out = self.experts(hidden_states=hs, router_logits=router_logits)
32+
33+
if isinstance(experts_out, tuple):
34+
if len(experts_out) != 2:
35+
raise RuntimeError(f"unexpected experts() tuple length={len(experts_out)}; "
36+
"expected (shared_out, fused_out).")
37+
shared_out, fused_out = experts_out
38+
if fused_out is None:
39+
raise RuntimeError("experts() returned fused_out=None")
40+
out = fused_out if shared_out is None else (shared_out + fused_out)
41+
else:
42+
# backward compatibility (FusedMoE)
43+
out = experts_out
44+
45+
if is_seq_parallel:
2646
out = tensor_model_parallel_all_gather(out, 0)
2747
out = out[:num_tokens]
48+
else:
49+
# from upstream : TP>1 may require a reduction here.
50+
tp_size = getattr(self, "tp_size", 1)
51+
if tp_size > 1 and hasattr(self.experts, "maybe_all_reduce_tensor_model_parallel"):
52+
out = self.experts.maybe_all_reduce_tensor_model_parallel(out)
2853

2954
return out.reshape(*orig_shape[:-1], hidden_dim)
3055

@@ -33,8 +58,9 @@ def upgrade_qwen3_moe_blocks_inplace(language_model: nn.Module) -> int:
3358
lm_model = getattr(language_model, "model", None)
3459
layers = getattr(lm_model, "layers", None)
3560
if layers is None:
36-
return
61+
return 0
3762

63+
upgraded = 0
3864
for layer in layers:
3965
mlp = getattr(layer, "mlp", None)
4066
if mlp is None:
@@ -46,3 +72,6 @@ def upgrade_qwen3_moe_blocks_inplace(language_model: nn.Module) -> int:
4672
if isinstance(mlp, UpstreamQwen3MoeSparseMoeBlock):
4773
mlp.__class__ = HpuQwen3MoeSparseMoeBlock
4874
mlp._hpu_accept_3d_installed = True
75+
upgraded += 1
76+
77+
return upgraded

vllm_gaudi/models/qwen3_vl_moe.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
1313
super().__init__(vllm_config=vllm_config, prefix=prefix)
1414

1515
quant_config = getattr(self, "quant_config", None)
16-
multimodal_config = getattr(vllm_config.model_config, "multimodal_config", None)
1716

1817
if hasattr(self, "visual") and self.visual is not None:
1918
self.visual = HPUQwen3_VisionTransformer(
2019
self.config.vision_config,
2120
norm_eps=getattr(self.config, "rms_norm_eps", 1e-6),
2221
quant_config=quant_config,
23-
multimodal_config=multimodal_config,
2422
prefix=maybe_prefix(prefix, "visual"),
2523
)
2624

0 commit comments

Comments (0)