Skip to content

Commit 14d4ed5

Browse files
authored
[BugFix] Fix aclgraph accuracy problem in A2. (#3163)
This PR fixes an accuracy problem of aclgraph on A2. The problem was introduced by PR #2980, which exposed the `all_reduce` of shared_experts to torch dynamo. This PR moves all the code into forward_impl to shield it from torch dynamo. - vLLM version: v0.10.2 - vLLM main: vllm-project/vllm@17b4c66 --------- Signed-off-by: whx-sjtu <[email protected]>
1 parent c3fee66 commit 14d4ed5

File tree

2 files changed

+30
-19
lines changed

2 files changed

+30
-19
lines changed

vllm_ascend/ops/common_fused_moe.py

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -203,12 +203,8 @@ def maybe_all_reduce_tensor_model_parallel(
203203
`finalize` function. In `allgathercommimpl`, we still need to all-reduce the
204204
outputs since each rank only has partial outputs.
205205
"""
206-
forward_context = get_forward_context()
207-
moe_comm_type = forward_context.moe_comm_type
208-
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2}:
209-
return final_hidden_states
210-
else:
211-
return tensor_model_parallel_all_reduce(final_hidden_states)
206+
return torch.ops.vllm.maybe_all_reduce_tensor_model_parallel(
207+
final_hidden_states)
212208

213209
def forward_impl(self, hidden_states: torch.Tensor,
214210
router_logits: torch.Tensor):
@@ -333,6 +329,15 @@ def forward(
333329
hidden_states: torch.Tensor,
334330
router_logits: torch.Tensor,
335331
) -> tuple[torch.Tensor, torch.Tensor]:
332+
shared_out, fused_out = AscendFusedMoE.forward(
333+
self,
334+
hidden_states=hidden_states,
335+
router_logits=router_logits,
336+
)
337+
return shared_out, fused_out
338+
339+
def forward_impl(self, hidden_states: torch.Tensor,
340+
router_logits: torch.Tensor):
336341
# Make sure the shared experts stream begins after hidden_states are ready.
337342
if self.multistream_overlap_shared_expert:
338343
self.shared_expert_stream.wait_stream( # type: ignore
@@ -347,26 +352,15 @@ def forward(
347352
moe_comm_type = forward_context.moe_comm_type
348353
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2}:
349354
shared_out = tensor_model_parallel_all_reduce(shared_out)
350-
351-
_, fused_out = AscendFusedMoE.forward(
355+
fused_output = AscendFusedMoE.forward_impl(
352356
self,
353357
hidden_states=hidden_states,
354358
router_logits=router_logits,
355359
)
356360
# Make sure the default stream waits for the shared experts stream to finish.
357361
if self.multistream_overlap_shared_expert:
358362
torch.npu.current_stream().wait_stream(self.shared_expert_stream)
359-
return shared_out, fused_out
360-
361-
def forward_impl(self, hidden_states: torch.Tensor,
362-
router_logits: torch.Tensor):
363-
shared_output = torch.empty(1)
364-
fused_output = AscendFusedMoE.forward_impl(
365-
self,
366-
hidden_states=hidden_states,
367-
router_logits=router_logits,
368-
)
369-
return shared_output, fused_output
363+
return shared_out, fused_output
370364

371365

372366
UnquantizedFusedMoEMethod.__init__ = unquantized_fused_moe_init_func

vllm_ascend/ops/register_custom_ops.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from vllm.utils import direct_register_custom_op
1111

1212
import vllm_ascend.envs as envs_ascend
13+
from vllm_ascend.ascend_forward_context import MoECommType
1314

1415

1516
def _maybe_chunk_residual_impl(x: torch.Tensor,
@@ -147,6 +148,16 @@ def _maybe_wait_prefetch_done_impl_fake(x: torch.Tensor) -> None:
147148
return
148149

149150

151+
def _maybe_all_reduce_tensor_model_parallel_impl(
152+
final_hidden_states: torch.Tensor) -> torch.Tensor:
153+
forward_context = get_forward_context()
154+
moe_comm_type = forward_context.moe_comm_type
155+
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2}:
156+
return final_hidden_states
157+
else:
158+
return tensor_model_parallel_all_reduce(final_hidden_states)
159+
160+
150161
direct_register_custom_op(op_name="maybe_chunk_residual",
151162
op_func=_maybe_chunk_residual_impl,
152163
fake_impl=lambda x, residual: residual,
@@ -182,3 +193,9 @@ def _maybe_wait_prefetch_done_impl_fake(x: torch.Tensor) -> None:
182193
fake_impl=_maybe_wait_prefetch_done_impl_fake,
183194
mutates_args=[],
184195
dispatch_key="PrivateUse1")
196+
197+
direct_register_custom_op(op_name="maybe_all_reduce_tensor_model_parallel",
198+
op_func=_maybe_all_reduce_tensor_model_parallel_impl,
199+
fake_impl=lambda x: x,
200+
mutates_args=[],
201+
dispatch_key="PrivateUse1")

0 commit comments

Comments
 (0)