Commit fc46e54

Fix allgather error
1 parent 00ba071 commit fc46e54

File tree

2 files changed: 4 additions, 3 deletions


tests/ut/ops/test_token_dispatcher.py

Lines changed: 2 additions & 0 deletions
@@ -274,6 +274,8 @@ def test_token_combine_with_expert_map(self):
         self.dispatcher.original_shape = (3, 128)
         self.dispatcher.mask = torch.tensor([0, 1, 1, 0])
         hidden_states = torch.randn(6, 128)
+        self.dispatcher.expanded_row_idx = torch.tensor([0, 1, 1, 1, 1, 1],
+                                                        dtype=torch.int32)
 
         final_hidden_states = self.dispatcher.token_combine(hidden_states)
         self.assertEqual(final_hidden_states.shape, (6, 128))
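The added fixture lines give token_combine one entry in expanded_row_idx per permuted row of hidden_states, which is what the final shape assertion relies on. A minimal standalone check of that relationship, assuming (as the test values suggest) a one-index-per-row layout; plain PyTorch is enough here, since no torch_npu kernel is involved in the shape check:

import torch

# Mirror of the test fixture: six permuted rows, one combine index each.
hidden_states = torch.randn(6, 128)
expanded_row_idx = torch.tensor([0, 1, 1, 1, 1, 1], dtype=torch.int32)

# One index per permuted row along dim 0.
assert expanded_row_idx.shape[0] == hidden_states.shape[0]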

vllm_ascend/ops/moe/token_dispatcher.py

Lines changed: 2 additions & 3 deletions
@@ -381,6 +381,8 @@ def token_combine(self,
                       hidden_states: torch.Tensor,
                       bias: torch.Tensor = None):
         assert self.original_shape is not None
+
+        self.expanded_row_idx = torch.abs(self.expanded_row_idx)
         final_hidden_states = torch_npu.npu_moe_token_unpermute(
             permuted_tokens=hidden_states,
             sorted_indices=self.expanded_row_idx,
@@ -468,9 +470,6 @@ def __init__(self, **kwargs):
         super().__init__(**kwargs)
         self.with_quant = False
         self.num_local_experts = kwargs.get("num_local_experts", 0)
-        self.num_global_redundant_experts = kwargs.get(
-            "num_global_redundant_experts", 0)
-        self.num_experts = self.num_experts + self.num_global_redundant_experts
 
         self.hidden_shape = None
         self.topk_weights = None
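For intuition about the torch.abs normalization added to token_combine, the following is a plain-PyTorch stand-in for an index-based combine step. It is a sketch under stated assumptions, not the torch_npu kernel: cpu_unpermute is a hypothetical helper, and the idea that the allgather path can leave sign-flipped entries in expanded_row_idx is an assumption made for illustration, not something stated in the commit.

import torch

def cpu_unpermute(permuted_tokens: torch.Tensor,
                  sorted_indices: torch.Tensor,
                  num_tokens: int) -> torch.Tensor:
    # Hypothetical CPU stand-in: accumulate each permuted row back onto
    # the output row named by its index. Negative values are not valid
    # row positions here, which is why the caller normalizes them first.
    out = torch.zeros(num_tokens, permuted_tokens.shape[-1],
                      dtype=permuted_tokens.dtype)
    out.index_add_(0, sorted_indices.long(), permuted_tokens)
    return out

hidden_states = torch.randn(6, 128)
# Assumption for illustration: sign-flipped entries appear in the index;
# torch.abs restores usable row indices, mirroring the line added above.
expanded_row_idx = torch.tensor([0, -1, 1, -1, 1, 1], dtype=torch.int32)
combined = cpu_unpermute(hidden_states, torch.abs(expanded_row_idx), 6)
print(combined.shape)  # torch.Size([6, 128])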
