@@ -7,24 +7,49 @@
 from vllm.distributed import tensor_model_parallel_all_gather
 
 
+# 1st change: override only the forward func to support 3d tensor input:
+# convert (B, S, H) -> (B*S, H)
+# 2nd change: corresponding changes to upstream (8edaf385)
+# SharedFusedMoE support
 class HpuQwen3MoeSparseMoeBlock(UpstreamQwen3MoeSparseMoeBlock):
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         orig_shape = hidden_states.shape
         hidden_dim = orig_shape[-1]
 
-        hs = hidden_states.reshape(-1, hidden_dim)  # (T, H)
+        hs = hidden_states.reshape(-1, hidden_dim)  # (B*S, H)
         num_tokens = hs.shape[0]
 
-        if getattr(self, "is_sequence_parallel", False):
+        is_seq_parallel = getattr(self, "is_sequence_parallel", False)
+
+        if is_seq_parallel:
             hs = sequence_parallel_chunk(hs)
 
         router_logits, _ = self.gate(hs)
-        out = self.experts(hidden_states=hs, router_logits=router_logits)
 
-        if getattr(self, "is_sequence_parallel", False):
+        # SharedFusedMoE returns (shared_out, fused_out)
+        experts_out = self.experts(hidden_states=hs, router_logits=router_logits)
+
+        if isinstance(experts_out, tuple):
+            if len(experts_out) != 2:
+                raise RuntimeError(f"unexpected experts() tuple length={len(experts_out)}; "
+                                   "expected (shared_out, fused_out).")
+            shared_out, fused_out = experts_out
+            if fused_out is None:
+                raise RuntimeError("experts() returned fused_out=None")
+            out = fused_out if shared_out is None else (shared_out + fused_out)
+        else:
+            # backward compatibility (FusedMoE)
+            out = experts_out
+
+        if is_seq_parallel:
             out = tensor_model_parallel_all_gather(out, 0)
             out = out[:num_tokens]
+        else:
+            # from upstream: TP>1 may require a reduction here.
+            tp_size = getattr(self, "tp_size", 1)
+            if tp_size > 1 and hasattr(self.experts, "maybe_all_reduce_tensor_model_parallel"):
+                out = self.experts.maybe_all_reduce_tensor_model_parallel(out)
 
         return out.reshape(*orig_shape[:-1], hidden_dim)
 
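For reference, a minimal standalone sketch (not part of the diff) of the shape contract the override enforces: flatten (B, S, H) to (B*S, H) before routing, then restore the leading dimensions on return. The sizes below are illustrative, not values from this PR.

import torch

# Illustrative sizes; the override handles any (B, S, H) input.
B, S, H = 2, 4, 8
hidden_states = torch.randn(B, S, H)
orig_shape = hidden_states.shape

hs = hidden_states.reshape(-1, H)            # (B*S, H), as in forward()
assert hs.shape == (B * S, H)

out = hs                                     # stand-in for the MoE output
restored = out.reshape(*orig_shape[:-1], H)  # back to (B, S, H)
assert restored.shape == (B, S, H)
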
@@ -33,8 +58,9 @@ def upgrade_qwen3_moe_blocks_inplace(language_model: nn.Module) -> int:
     lm_model = getattr(language_model, "model", None)
     layers = getattr(lm_model, "layers", None)
     if layers is None:
-        return
+        return 0
 
+    upgraded = 0
     for layer in layers:
         mlp = getattr(layer, "mlp", None)
         if mlp is None:
@@ -46,3 +72,6 @@ def upgrade_qwen3_moe_blocks_inplace(language_model: nn.Module) -> int:
         if isinstance(mlp, UpstreamQwen3MoeSparseMoeBlock):
             mlp.__class__ = HpuQwen3MoeSparseMoeBlock
             mlp._hpu_accept_3d_installed = True
+            upgraded += 1
+
+    return upgraded
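A hedged usage sketch of the upgrade helper, assuming the attribute path it probes (language_model.model.layers) and the elided skip of layers whose mlp is None; the stand-in modules below are hypothetical, not from this PR.

import torch.nn as nn

# Hypothetical stand-in mirroring language_model.model.layers; a real
# model would hold a Qwen3MoeSparseMoeBlock in each layer's .mlp.
class _Layer(nn.Module):
    def __init__(self):
        super().__init__()
        self.mlp = None

class _LM(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Module()
        self.model.layers = nn.ModuleList([_Layer()])

# With mlp=None every layer is skipped, so the helper reports 0 upgrades.
assert upgrade_qwen3_moe_blocks_inplace(_LM()) == 0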