diff --git a/tests/ut/ops/test_token_dispatcher.py b/tests/ut/ops/test_token_dispatcher.py index cc2d30774d..e4da789175 100644 --- a/tests/ut/ops/test_token_dispatcher.py +++ b/tests/ut/ops/test_token_dispatcher.py @@ -274,6 +274,8 @@ def test_token_combine_with_expert_map(self): self.dispatcher.original_shape = (3, 128) self.dispatcher.mask = torch.tensor([0, 1, 1, 0]) hidden_states = torch.randn(6, 128) + self.dispatcher.expanded_row_idx = torch.tensor([0, 1, 1, 1, 1, 1], + dtype=torch.int32) final_hidden_states = self.dispatcher.token_combine(hidden_states) self.assertEqual(final_hidden_states.shape, (6, 128)) diff --git a/vllm_ascend/ops/moe/token_dispatcher.py b/vllm_ascend/ops/moe/token_dispatcher.py index b36cc44ff2..80129bb3fa 100644 --- a/vllm_ascend/ops/moe/token_dispatcher.py +++ b/vllm_ascend/ops/moe/token_dispatcher.py @@ -381,6 +381,8 @@ def token_combine(self, hidden_states: torch.Tensor, bias: torch.Tensor = None): assert self.original_shape is not None + + self.expanded_row_idx = torch.abs(self.expanded_row_idx) final_hidden_states = torch_npu.npu_moe_token_unpermute( permuted_tokens=hidden_states, sorted_indices=self.expanded_row_idx,