Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 34 additions & 5 deletions vllm_gaudi/models/qwen3_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,23 +8,48 @@


class HpuQwen3MoeSparseMoeBlock(UpstreamQwen3MoeSparseMoeBlock):
"""
Override forward to handle 3D tensor input (B,S,H) -> (B*S,H)
and SharedFusedMoE tuple returns.
"""

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
orig_shape = hidden_states.shape
hidden_dim = orig_shape[-1]

hs = hidden_states.reshape(-1, hidden_dim) # (T, H)
hs = hidden_states.reshape(-1, hidden_dim) # (B*S, H)
num_tokens = hs.shape[0]

if getattr(self, "is_sequence_parallel", False):
is_seq_parallel = getattr(self, "is_sequence_parallel", False)

if is_seq_parallel:
hs = sequence_parallel_chunk(hs)

router_logits, _ = self.gate(hs)
out = self.experts(hidden_states=hs, router_logits=router_logits)

if getattr(self, "is_sequence_parallel", False):
# SharedFusedMoE returns (shared_out, fused_out)
experts_out = self.experts(hidden_states=hs, router_logits=router_logits)

if isinstance(experts_out, tuple):
if len(experts_out) != 2:
raise RuntimeError(f"unexpected experts() tuple length={len(experts_out)}; "
"expected (shared_out, fused_out).")
shared_out, fused_out = experts_out
if fused_out is None:
raise RuntimeError("experts() returned fused_out=None")
out = fused_out if shared_out is None else (shared_out + fused_out)
else:
# backward compatibility (FusedMoE)
out = experts_out

if is_seq_parallel:
out = tensor_model_parallel_all_gather(out, 0)
out = out[:num_tokens]
else:
# from upstream : TP>1 may require a reduction here.
tp_size = getattr(self, "tp_size", 1)
if tp_size > 1 and hasattr(self.experts, "maybe_all_reduce_tensor_model_parallel"):
out = self.experts.maybe_all_reduce_tensor_model_parallel(out)

return out.reshape(*orig_shape[:-1], hidden_dim)

Expand All @@ -33,8 +58,9 @@ def upgrade_qwen3_moe_blocks_inplace(language_model: nn.Module) -> int:
lm_model = getattr(language_model, "model", None)
layers = getattr(lm_model, "layers", None)
if layers is None:
return
return 0

upgraded = 0
for layer in layers:
mlp = getattr(layer, "mlp", None)
if mlp is None:
Expand All @@ -46,3 +72,6 @@ def upgrade_qwen3_moe_blocks_inplace(language_model: nn.Module) -> int:
if isinstance(mlp, UpstreamQwen3MoeSparseMoeBlock):
mlp.__class__ = HpuQwen3MoeSparseMoeBlock
mlp._hpu_accept_3d_installed = True
upgraded += 1

return upgraded
2 changes: 0 additions & 2 deletions vllm_gaudi/models/qwen3_vl_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)

quant_config = getattr(self, "quant_config", None)
multimodal_config = getattr(vllm_config.model_config, "multimodal_config", None)

if hasattr(self, "visual") and self.visual is not None:
self.visual = HPUQwen3_VisionTransformer(
self.config.vision_config,
norm_eps=getattr(self.config, "rms_norm_eps", 1e-6),
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "visual"),
)

Expand Down
Loading