@@ -8,23 +8,47 @@
 
 
 class HpuQwen3MoeSparseMoeBlock(UpstreamQwen3MoeSparseMoeBlock):
-
+    """
+    Override forward to handle 3D tensor input (B, S, H) -> (B*S, H)
+    and SharedFusedMoE tuple returns.
+    """
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         orig_shape = hidden_states.shape
         hidden_dim = orig_shape[-1]
 
-        hs = hidden_states.reshape(-1, hidden_dim)  # (T, H)
+        hs = hidden_states.reshape(-1, hidden_dim)  # (B*S, H)
         num_tokens = hs.shape[0]
 
-        if getattr(self, "is_sequence_parallel", False):
+        is_seq_parallel = getattr(self, "is_sequence_parallel", False)
+
+        if is_seq_parallel:
             hs = sequence_parallel_chunk(hs)
 
         router_logits, _ = self.gate(hs)
-        out = self.experts(hidden_states=hs, router_logits=router_logits)
 
-        if getattr(self, "is_sequence_parallel", False):
+        # SharedFusedMoE returns (shared_out, fused_out)
+        experts_out = self.experts(hidden_states=hs, router_logits=router_logits)
+
+        if isinstance(experts_out, tuple):
+            if len(experts_out) != 2:
+                raise RuntimeError(f"unexpected experts() tuple length={len(experts_out)}; "
+                                   "expected (shared_out, fused_out).")
+            shared_out, fused_out = experts_out
+            if fused_out is None:
+                raise RuntimeError("experts() returned fused_out=None")
+            out = fused_out if shared_out is None else (shared_out + fused_out)
+        else:
+            # Backward compatibility (FusedMoE).
+            out = experts_out
+
+        if is_seq_parallel:
             out = tensor_model_parallel_all_gather(out, 0)
             out = out[:num_tokens]
+        else:
+            # From upstream: TP > 1 may require a reduction here.
+            tp_size = getattr(self, "tp_size", 1)
+            if tp_size > 1 and hasattr(self.experts, "maybe_all_reduce_tensor_model_parallel"):
+                out = self.experts.maybe_all_reduce_tensor_model_parallel(out)
 
         return out.reshape(*orig_shape[:-1], hidden_dim)
 
@@ -33,8 +57,9 @@ def upgrade_qwen3_moe_blocks_inplace(language_model: nn.Module) -> int:
     lm_model = getattr(language_model, "model", None)
     layers = getattr(lm_model, "layers", None)
     if layers is None:
-        return
+        return 0
 
+    upgraded = 0
     for layer in layers:
         mlp = getattr(layer, "mlp", None)
         if mlp is None:
@@ -46,3 +71,6 @@ def upgrade_qwen3_moe_blocks_inplace(language_model: nn.Module) -> int:
         if isinstance(mlp, UpstreamQwen3MoeSparseMoeBlock):
             mlp.__class__ = HpuQwen3MoeSparseMoeBlock
             mlp._hpu_accept_3d_installed = True
+            upgraded += 1
+
+    return upgraded
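
Usage sketch, assuming `model` is an already-loaded Qwen3-MoE model that exposes
`language_model` and a config `hidden_size`; these names are placeholders for
illustration, only `upgrade_qwen3_moe_blocks_inplace` and the 3D-input behaviour
come from the patch above:

    import torch

    # Swap the upstream sparse MoE blocks for the HPU-aware subclass in place.
    patched = upgrade_qwen3_moe_blocks_inplace(model.language_model)
    print(f"patched {patched} sparse MoE blocks")

    # The overridden forward now accepts (B, S, H) input and returns the same shape.
    x = torch.randn(2, 16, model.config.hidden_size)
    y = model.language_model.model.layers[0].mlp(x)
    assert y.shape == x.shape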