Commit 4f510bc

[Model] Removes redundant all-reduce operation in Qwen3MoeSparseMoeBlock (#23169)
Signed-off-by: Yizhou Liu <[email protected]>
1 parent 1298c67 commit 4f510bc


vllm/model_executor/models/qwen3_moe.py

Lines changed: 1 addition & 5 deletions
@@ -139,7 +139,7 @@ def __init__(
                                 top_k=config.num_experts_per_tok,
                                 hidden_size=config.hidden_size,
                                 intermediate_size=config.moe_intermediate_size,
-                                reduce_results=False,
+                                reduce_results=True,
                                 renormalize=config.norm_topk_prob,
                                 quant_config=quant_config,
                                 prefix=f"{prefix}.experts",
@@ -163,10 +163,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         final_hidden_states = self.experts(hidden_states=hidden_states,
                                            router_logits=router_logits)

-        if self.tp_size > 1:
-            final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( # noqa E501
-                final_hidden_states)
-
         return final_hidden_states.view(orig_shape)
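Why the explicit all-reduce became redundant: with reduce_results=True, the FusedMoE layer performs the tensor-parallel all-reduce on the partial expert outputs itself, so the separate maybe_all_reduce_tensor_model_parallel call in forward is no longer needed. Below is a minimal sketch of the resulting forward method, for illustration only; the gate call and the reshape bookkeeping are not visible in this diff and are assumptions based on typical vLLM MoE blocks:

import torch

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    # Flatten (batch, seq, hidden) to (tokens, hidden) for routing (assumed).
    orig_shape = hidden_states.shape
    hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
    # Router scores each token against every expert (assumed gate layer).
    router_logits, _ = self.gate(hidden_states)
    # With reduce_results=True, FusedMoE all-reduces the partial expert
    # outputs across tensor-parallel ranks internally; no explicit
    # maybe_all_reduce_tensor_model_parallel call is needed here anymore.
    final_hidden_states = self.experts(hidden_states=hidden_states,
                                       router_logits=router_logits)
    return final_hidden_states.view(orig_shape)

Under tensor parallelism the behavior should be unchanged: the reduction still happens once per MoE block, just inside the fused layer rather than via a separate model-level call.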