Commit fb61400

Fix qwen3 vl moe execution failure
Signed-off-by: Seunghyuk Park <separk@habana.ai>
1 parent 1e012ec commit fb61400

2 files changed: 34 additions & 8 deletions

vllm_gaudi/models/qwen3_moe.py

Lines changed: 34 additions & 6 deletions
@@ -8,23 +8,47 @@


 class HpuQwen3MoeSparseMoeBlock(UpstreamQwen3MoeSparseMoeBlock):
-
+    """
+    Override forward to handle 3D tensor input (B,S,H) -> (B*S,H)
+    and SharedFusedMoE tuple returns.
+    """
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         orig_shape = hidden_states.shape
         hidden_dim = orig_shape[-1]

-        hs = hidden_states.reshape(-1, hidden_dim)  # (T, H)
+        hs = hidden_states.reshape(-1, hidden_dim)  # (B*S, H)
         num_tokens = hs.shape[0]

-        if getattr(self, "is_sequence_parallel", False):
+        is_seq_parallel = getattr(self, "is_sequence_parallel", False)
+
+        if is_seq_parallel:
             hs = sequence_parallel_chunk(hs)

         router_logits, _ = self.gate(hs)
-        out = self.experts(hidden_states=hs, router_logits=router_logits)

-        if getattr(self, "is_sequence_parallel", False):
+        # SharedFusedMoE returns (shared_out, fused_out)
+        experts_out = self.experts(hidden_states=hs, router_logits=router_logits)
+
+        if isinstance(experts_out, tuple):
+            if len(experts_out) != 2:
+                raise RuntimeError(f"unexpected experts() tuple length={len(experts_out)}; "
+                                   "expected (shared_out, fused_out).")
+            shared_out, fused_out = experts_out
+            if fused_out is None:
+                raise RuntimeError("experts() returned fused_out=None")
+            out = fused_out if shared_out is None else (shared_out + fused_out)
+        else:
+            # backward compatibility (FusedMoE)
+            out = experts_out
+
+        if is_seq_parallel:
             out = tensor_model_parallel_all_gather(out, 0)
             out = out[:num_tokens]
+        else:
+            # from upstream: TP>1 may require a reduction here.
+            tp_size = getattr(self, "tp_size", 1)
+            if tp_size > 1 and hasattr(self.experts, "maybe_all_reduce_tensor_model_parallel"):
+                out = self.experts.maybe_all_reduce_tensor_model_parallel(out)

         return out.reshape(*orig_shape[:-1], hidden_dim)


@@ -33,8 +57,9 @@ def upgrade_qwen3_moe_blocks_inplace(language_model: nn.Module) -> int:
     lm_model = getattr(language_model, "model", None)
     layers = getattr(lm_model, "layers", None)
     if layers is None:
-        return
+        return 0

+    upgraded = 0
     for layer in layers:
         mlp = getattr(layer, "mlp", None)
         if mlp is None:

@@ -46,3 +71,6 @@ def upgrade_qwen3_moe_blocks_inplace(language_model: nn.Module) -> int:
         if isinstance(mlp, UpstreamQwen3MoeSparseMoeBlock):
             mlp.__class__ = HpuQwen3MoeSparseMoeBlock
             mlp._hpu_accept_3d_installed = True
+            upgraded += 1
+
+    return upgraded
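
For reference, the core of the fix is the new handling of the experts() return value. Below is a minimal standalone sketch of that (shared_out, fused_out) combination rule; it mirrors the logic added in HpuQwen3MoeSparseMoeBlock.forward but is not the vLLM code path itself, and the helper name combine_experts_output is an illustrative assumption.

import torch

def combine_experts_output(experts_out):
    # Mirrors the tuple handling added in HpuQwen3MoeSparseMoeBlock.forward:
    # SharedFusedMoE returns (shared_out, fused_out); plain FusedMoE returns a tensor.
    if isinstance(experts_out, tuple):
        if len(experts_out) != 2:
            raise RuntimeError(f"unexpected experts() tuple length={len(experts_out)}")
        shared_out, fused_out = experts_out
        if fused_out is None:
            raise RuntimeError("experts() returned fused_out=None")
        return fused_out if shared_out is None else shared_out + fused_out
    return experts_out  # backward-compatible FusedMoE tensor return

tokens = torch.randn(4, 8)           # stands in for hs after the (B, S, H) -> (B*S, H) reshape
fused = torch.randn_like(tokens)
shared = torch.randn_like(tokens)

print(combine_experts_output((shared, fused)).shape)  # shared expert present: shared + fused
print(combine_experts_output((None, fused)).shape)    # no shared expert: fused only
print(combine_experts_output(fused).shape)            # legacy tensor return passes through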

vllm_gaudi/models/qwen3_vl_moe.py

Lines changed: 0 additions & 2 deletions
@@ -13,14 +13,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__(vllm_config=vllm_config, prefix=prefix)

         quant_config = getattr(self, "quant_config", None)
-        multimodal_config = getattr(vllm_config.model_config, "multimodal_config", None)

         if hasattr(self, "visual") and self.visual is not None:
             self.visual = HPUQwen3_VisionTransformer(
                 self.config.vision_config,
                 norm_eps=getattr(self.config, "rms_norm_eps", 1e-6),
                 quant_config=quant_config,
-                multimodal_config=multimodal_config,
                 prefix=maybe_prefix(prefix, "visual"),
             )


0 commit comments