Skip to content

Commit 035ce61

Browse files
committed
Fix allgather error
1 parent 00ba071 commit 035ce61

File tree

2 files changed

+4
-0
lines changed

2 files changed

+4
-0
lines changed

tests/ut/ops/test_token_dispatcher.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,8 @@ def test_token_combine_with_expert_map(self):
274274
self.dispatcher.original_shape = (3, 128)
275275
self.dispatcher.mask = torch.tensor([0, 1, 1, 0])
276276
hidden_states = torch.randn(6, 128)
277+
self.dispatcher.expanded_row_idx = torch.tensor([0, 1, 1, 1, 1, 1],
278+
dtype=torch.int32)
277279

278280
final_hidden_states = self.dispatcher.token_combine(hidden_states)
279281
self.assertEqual(final_hidden_states.shape, (6, 128))

vllm_ascend/ops/moe/token_dispatcher.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,8 @@ def token_combine(self,
381381
hidden_states: torch.Tensor,
382382
bias: torch.Tensor = None):
383383
assert self.original_shape is not None
384+
385+
self.expanded_row_idx = torch.abs(self.expanded_row_idx)
384386
final_hidden_states = torch_npu.npu_moe_token_unpermute(
385387
permuted_tokens=hidden_states,
386388
sorted_indices=self.expanded_row_idx,

0 commit comments

Comments
 (0)