diff --git a/tests/e2e/singlecard/ops/test_fused_moe.py b/tests/e2e/singlecard/ops/test_fused_moe.py
index 21e0a4d0b0..ab673a4f0b 100644
--- a/tests/e2e/singlecard/ops/test_fused_moe.py
+++ b/tests/e2e/singlecard/ops/test_fused_moe.py
@@ -92,8 +92,15 @@ def test_fused_experts(
     score = torch.softmax(score, dim=-1, dtype=dtype)
     topk_weights, topk_ids = torch.topk(score, topk)
     topk_ids = topk_ids.to(torch.int32)
-
-    output = fused_experts(a, w1, w2, topk_weights, topk_ids, topk, e_map)
+    row_idx = (torch.arange(
+        0,
+        m * topk,
+        device=device,
+        dtype=torch.int32,
+    ).view(topk, -1).permute(1, 0).contiguous())
+
+    output = fused_experts(a, w1, w2, topk_weights, topk_ids, row_idx, topk,
+                           e_map)
     torch_output = torch_moe(a, w1, w2, topk_weights, topk_ids, topk, e_map)
     # TODO: The native params are: atol=2e-2, rtol=0, maybe related to the nan problem
     torch.testing.assert_close(output, torch_output, atol=4e-2, rtol=1)
@@ -148,7 +155,7 @@ def test_select_experts(
     mock_native_grouped_topk.side_effect = lambda x, num_groups, k: torch.randn_like(
         x)
 
-    topk_weights, topk_ids = select_experts(
+    topk_weights, topk_ids, row_idx = select_experts(
         hidden_states=hidden_states,
         router_logits=router_logits,
         top_k=topk,
@@ -169,6 +176,7 @@ def test_select_experts(
     assert topk_weights.shape == (m, topk)
     assert topk_ids.shape == (m, topk)
    assert topk_ids.dtype == torch.int32
+    assert row_idx.shape == (m, topk)
 
 
 @pytest.mark.parametrize("device", DEVICE)
diff --git a/tests/ut/ops/test_fused_ops.py b/tests/ut/ops/test_fused_ops.py
index 42370ebe51..6db32e68ad 100644
--- a/tests/ut/ops/test_fused_ops.py
+++ b/tests/ut/ops/test_fused_ops.py
@@ -405,7 +405,7 @@ def test_select_experts(self, mock_dist_env, mock_moe_env,
         x = torch.randn(8, 2)
         router_logits = torch.randn(8, 2)
 
-        topk_weights, topk_ids = select_experts(
+        topk_weights, topk_ids, _ = select_experts(
             hidden_states=x,
             router_logits=router_logits,
             top_k=2,
diff --git a/tests/ut/quantization/test_w8a8.py b/tests/ut/quantization/test_w8a8.py
index 63b017c902..669f2b9159 100644
--- a/tests/ut/quantization/test_w8a8.py
+++ b/tests/ut/quantization/test_w8a8.py
@@ -719,12 +719,12 @@ def setUp(self):
 
     def test_softmax_scoring(self):
         """Test softmax scoring function"""
-        weights, ids = select_experts(hidden_states=self.hidden_states,
-                                      router_logits=self.router_logits,
-                                      top_k=self.top_k,
-                                      use_grouped_topk=False,
-                                      renormalize=False,
-                                      scoring_func="softmax")
+        weights, ids, _ = select_experts(hidden_states=self.hidden_states,
+                                         router_logits=self.router_logits,
+                                         top_k=self.top_k,
+                                         use_grouped_topk=False,
+                                         renormalize=False,
+                                         scoring_func="softmax")
 
         self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
         self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
@@ -732,12 +732,12 @@ def test_softmax_scoring(self):
 
     def test_sigmoid_scoring(self):
         """Test sigmoid scoring function"""
-        weights, ids = select_experts(hidden_states=self.hidden_states,
-                                      router_logits=self.router_logits,
-                                      top_k=self.top_k,
-                                      use_grouped_topk=False,
-                                      renormalize=False,
-                                      scoring_func="sigmoid")
+        weights, ids, _ = select_experts(hidden_states=self.hidden_states,
+                                         router_logits=self.router_logits,
+                                         top_k=self.top_k,
+                                         use_grouped_topk=False,
+                                         renormalize=False,
+                                         scoring_func="sigmoid")
 
         self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
         self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
@@ -760,13 +760,13 @@ def test_grouped_topk(self, mock_topk):
                                              self.top_k,
                                              dtype=torch.long))
 
-        weights, ids = select_experts(hidden_states=self.hidden_states,
-                                      router_logits=self.router_logits,
-                                      top_k=self.top_k,
-                                      use_grouped_topk=True,
-                                      renormalize=False,
-                                      topk_group=4,
-                                      num_expert_group=2)
+        weights, ids, _ = select_experts(hidden_states=self.hidden_states,
+                                         router_logits=self.router_logits,
+                                         top_k=self.top_k,
+                                         use_grouped_topk=True,
+                                         renormalize=False,
+                                         topk_group=4,
+                                         num_expert_group=2)
 
         mock_topk.assert_called()
         self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
@@ -780,7 +780,7 @@ def test_grouped_topk_with_correction_bias(self, mock_grouped_topk):
                                               self.num_experts)
         e_score_correction_bias = torch.randn(self.num_experts)
 
-        weights, ids = select_experts(
+        weights, ids, _ = select_experts(
             hidden_states=self.hidden_states,
             router_logits=self.router_logits,
             top_k=self.top_k,
@@ -803,7 +803,7 @@ def test_custom_routing_function(self):
                                              self.top_k,
                                             dtype=torch.int32))
 
-        weights, ids = select_experts(
+        weights, ids, _ = select_experts(
             hidden_states=self.hidden_states,
             router_logits=self.router_logits,
             top_k=self.top_k,
@@ -824,7 +824,7 @@ def test_renormalize(self, mock_topk):
                                              self.top_k,
                                              dtype=torch.long))
 
-        weights, _ = select_experts(
+        weights, ids, _ = select_experts(
             hidden_states=self.hidden_states,
             router_logits=self.router_logits,
             top_k=self.top_k,
@@ -844,7 +844,7 @@ def test_output_dtypes(self, mock_topk):
                                             self.top_k,
                                             dtype=torch.long))
 
-        weights, ids = select_experts(
+        weights, ids, _ = select_experts(
             hidden_states=self.hidden_states,
             router_logits=self.router_logits,
             top_k=self.top_k,
diff --git a/tests/ut/quantization/test_w8a8_dynamic.py b/tests/ut/quantization/test_w8a8_dynamic.py
index 59ab60487d..0e07eb107c 100644
--- a/tests/ut/quantization/test_w8a8_dynamic.py
+++ b/tests/ut/quantization/test_w8a8_dynamic.py
@@ -55,6 +55,12 @@ def test_fused_experts_with_all2all(self, mock_moe_init_routing,
             torch.randn(self.num_tokens),
         )
         mock_moe_finalize_routing.return_value = self.placeholder
+        row_idx_len = self.num_tokens * 8
+        row_idx = (torch.arange(
+            0,
+            row_idx_len,
+            dtype=torch.int32,
+        ).view(8, -1).permute(1, 0).contiguous())
 
         result = fused_experts_with_all2all(
             hidden_states=self.placeholder,
@@ -64,6 +70,7 @@ def test_fused_experts_with_all2all(self, mock_moe_init_routing,
             w2_scale=self.placeholder,
             topk_weights=self.placeholder,
             topk_ids=self.placeholder,
+            row_idx=row_idx,
             top_k=8,
             expert_map=expert_map,
             ep_group=ep_group,
diff --git a/vllm_ascend/ops/common_fused_moe.py b/vllm_ascend/ops/common_fused_moe.py
index 19a86a7d03..caaa67ea53 100644
--- a/vllm_ascend/ops/common_fused_moe.py
+++ b/vllm_ascend/ops/common_fused_moe.py
@@ -68,7 +68,7 @@ def forward_oot(
        logical_to_physical_map: Optional[torch.Tensor] = None,
        logical_replica_count: Optional[torch.Tensor] = None) -> torch.Tensor:
 
-    topk_weights, topk_ids = select_experts(
+    topk_weights, topk_ids, _ = select_experts(
        hidden_states=x,
        router_logits=router_logits,
        top_k=top_k,
diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py
index 0d6dc9c3b0..47d5d603a5 100644
--- a/vllm_ascend/ops/fused_moe.py
+++ b/vllm_ascend/ops/fused_moe.py
@@ -391,6 +391,7 @@ def fused_experts_with_all2all(
     w2: torch.Tensor,
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
+    row_idx: torch.Tensor,
     top_k: int,
     expert_map: torch.Tensor = None,
     ep_group: GroupCoordinator = None,
@@ -401,17 +402,10 @@ def fused_experts_with_all2all(
 
     num_tokens, _ = hidden_states.shape
     num_experts = w1.shape[0]
-    device = hidden_states.device
 
     if expert_map is not None:
         global_num_experts = len(expert_map)
         local_num_experts = global_num_experts // ep_group.world_size
-        row_idx_len = num_tokens * top_k
-        row_idx = (torch.arange(0,
-                                row_idx_len,
-                                dtype=torch.int32,
-                                device=device).view(top_k, -1).permute(
-                                    1, 0).contiguous())
         hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
             hidden_states,
             row_idx=row_idx,
@@ -445,12 +439,6 @@ def fused_experts_with_all2all(
 
         hidden_states = hidden_states[sorted_idx]
     else:
-        row_idx_len = num_tokens * top_k
-        row_idx = torch.arange(0,
-                               row_idx_len,
-                               dtype=torch.int32,
-                               device=topk_weights.device).view(
-                                   top_k, -1).permute(1, 0).contiguous()
         hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
             hidden_states,
             row_idx=row_idx,
@@ -524,6 +512,7 @@ def fused_experts_with_all2all_buffer(
     w2: torch.Tensor,
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
+    row_idx: torch.Tensor,
     top_k: int,
     max_model_len: int,
     global_batch_size: int,
@@ -535,14 +524,10 @@ def fused_experts_with_all2all_buffer(
 
     hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
     num_tokens, _ = hidden_states.shape
-    device = hidden_states.device
 
     global_num_experts = len(expert_map)
     local_num_experts = global_num_experts // ep_group.world_size
     row_idx_len = num_tokens * top_k
-    row_idx = (torch.arange(0, row_idx_len, dtype=torch.int32,
-                            device=device).view(top_k,
-                                                -1).permute(1, 0).contiguous())
     hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
         hidden_states,
         row_idx=row_idx,
@@ -755,6 +740,7 @@ def fused_experts(
     w2: torch.Tensor,
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
+    row_idx: torch.Tensor,
     top_k: int,
     expert_map: torch.Tensor = None,
     apply_router_weight_on_input: bool = False,
@@ -846,12 +832,6 @@ def fused_experts(
         # Rearrange hidden_states
         sorted_hidden_states = hidden_states[sorted_token_indices]
     else:
-        row_idx_len = num_tokens * top_k
-        row_idx = (torch.arange(0,
-                                row_idx_len,
-                                dtype=torch.int32,
-                                device=device).view(top_k, -1).permute(
-                                    1, 0).contiguous())
         active_num = max_num_tokens if max_num_tokens is not None else num_tokens
         sorted_hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
             hidden_states,
@@ -975,7 +955,7 @@ def apply(
         **kwargs,
     ) -> torch.Tensor:
 
-        topk_weights, topk_ids = select_experts(
+        topk_weights, topk_ids, row_idx = select_experts(
             hidden_states=x,
             router_logits=router_logits,
             top_k=top_k,
@@ -1019,6 +999,7 @@ def apply(
                 w2=layer.w2_weight,
                 topk_weights=topk_weights,
                 topk_ids=topk_ids,
+                row_idx=row_idx,
                 top_k=top_k,
                 expert_map=expert_map)
         elif MOE_ALL2ALL_BUFFER:
@@ -1028,6 +1009,7 @@ def apply(
                 w2=layer.w2_weight,
                 topk_weights=topk_weights,
                 topk_ids=topk_ids,
+                row_idx=row_idx,
                 top_k=top_k,
                 max_model_len=self.max_model_len,
                 global_batch_size=self.global_batch_size,
@@ -1049,6 +1031,7 @@ def apply(
                 w2=layer.w2_weight,
                 topk_weights=topk_weights,
                 topk_ids=topk_ids,
+                row_idx=row_idx,
                 top_k=top_k,
                 expert_map=expert_map,
                 ep_group=get_ep_group())
diff --git a/vllm_ascend/ops/layers/experts_selector.py b/vllm_ascend/ops/layers/experts_selector.py
index c906cf3442..11524ac4a0 100644
--- a/vllm_ascend/ops/layers/experts_selector.py
+++ b/vllm_ascend/ops/layers/experts_selector.py
@@ -20,6 +20,17 @@
 import torch_npu
 
 
+def return_row_idx(hidden_states, top_k):
+    num_tokens = hidden_states.shape[0]
+    row_idx_len = num_tokens * top_k
+    row_idx = (torch.arange(0,
+                            row_idx_len,
+                            dtype=torch.int32,
+                            device=hidden_states.device).view(
+                                top_k, -1).permute(1, 0).contiguous())
+    return row_idx
+
+
 def select_experts(hidden_states: torch.Tensor,
                    router_logits: torch.Tensor,
                    top_k: int,
@@ -56,7 +67,8 @@ def select_experts(hidden_states: torch.Tensor,
         topk_ids: selected expert IDs of shape (num_tokens, top_k).
     """
 
-    topk_weights, topk_ids = _select_experts_with_fusion_ops(
+    topk_weights, topk_ids, row_idx = _select_experts_with_fusion_ops(
+        hidden_states=hidden_states,
         router_logits=router_logits,
         top_k=top_k,
         use_grouped_topk=use_grouped_topk,
@@ -83,7 +95,9 @@ def select_experts(hidden_states: torch.Tensor,
             e_score_correction_bias=e_score_correction_bias,
             global_num_experts=global_num_experts,
         )
-    return topk_weights, topk_ids
+    if row_idx is None:
+        row_idx = return_row_idx(hidden_states, top_k)
+    return topk_weights, topk_ids, row_idx
 
 
 def _native_grouped_topk(
@@ -156,6 +170,7 @@ def _select_expert_use_group_topk(
 
 
 def _select_experts_with_fusion_ops(
+    hidden_states: torch.Tensor,
     router_logits: torch.Tensor,
     top_k: int,
     use_grouped_topk: bool,
@@ -168,7 +183,7 @@ def _select_experts_with_fusion_ops(
     global_num_experts: int = -1,
     is_unquantized: bool = False):
 
-    topk_weights, topk_ids = None, None
+    topk_weights, topk_ids, row_idx = None, None, None
     # NOTE: now npu_moe_gating_top_k can only support 'group_count=256' pattern
     is_deepseek_v3_r1 = global_num_experts == 256
     if is_deepseek_v3_r1:
@@ -186,14 +201,14 @@ def _select_experts_with_fusion_ops(
             # y2_flag=False, # old api; should the third output be output
             routed_scaling_factor=1,
             eps=float(1e-20))
-
+        row_idx = return_row_idx(hidden_states, top_k)
     if not use_grouped_topk and custom_routing_function is None and scoring_func == "softmax" and is_unquantized:
-        topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k_softmax(
+        topk_weights, topk_ids, row_idx = torch_npu.npu_moe_gating_top_k_softmax(
             x=router_logits, finished=None, k=top_k)
         topk_ids = topk_ids.to(torch.int32)
         topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
 
-    return topk_weights, topk_ids
+    return topk_weights, topk_ids, row_idx
 
 
 def _native_select_experts(
diff --git a/vllm_ascend/quantization/w4a8_dynamic.py b/vllm_ascend/quantization/w4a8_dynamic.py
index f7d838dd32..a724615522 100644
--- a/vllm_ascend/quantization/w4a8_dynamic.py
+++ b/vllm_ascend/quantization/w4a8_dynamic.py
@@ -268,7 +268,7 @@ def apply(
             1] == global_num_experts, "Number of global experts mismatch"
 
         # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
-        topk_weights, topk_ids = select_experts(
+        topk_weights, topk_ids, row_idx = select_experts(
             hidden_states=x,
             router_logits=router_logits,
             top_k=top_k,
@@ -334,6 +334,7 @@ def apply(
             w2_scale_bias=layer.w2_scale_bias,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
+            row_idx=row_idx,
             top_k=top_k,
             expert_map=expert_map,
             ep_group=self.ep_group,
diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py
index 21615f3c7f..cba090b0fe 100644
--- a/vllm_ascend/quantization/w8a8_dynamic.py
+++ b/vllm_ascend/quantization/w8a8_dynamic.py
@@ -365,14 +365,9 @@ def fused_experts_with_mc2(
     return hidden_states, shared_output
 
 
-def init_routing_quant(hidden_states, top_k, topk_ids, global_num_experts):
+def init_routing_quant(hidden_states, top_k, topk_ids, row_idx,
+                       global_num_experts):
     num_tokens, _ = hidden_states.shape
-    row_idx_len = num_tokens * top_k
-    row_idx = (torch.arange(0,
-                            row_idx_len,
-                            dtype=torch.int32,
-                            device=hidden_states.device).view(
-                                top_k, -1).permute(1, 0).contiguous())
     hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
         hidden_states,
         row_idx=row_idx,
@@ -398,6 +393,7 @@ def fused_experts_with_all2all(
     w2_scale: torch.Tensor,
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
+    row_idx: torch.Tensor,
     top_k: int,
     expert_map: torch.Tensor = None,
     ep_group: GroupCoordinator = None,
@@ -431,7 +427,7 @@ def fused_experts_with_all2all(
             )
         else:
             quantized_tokens, expanded_row_idx, global_expert_tokens, token_scales = init_routing_quant(
-                hidden_states, top_k, topk_ids, global_num_experts)
+                hidden_states, top_k, topk_ids, row_idx, global_num_experts)
 
             gather_sizes = global_expert_tokens.new_empty(
                 global_expert_tokens.shape[0])
@@ -463,12 +459,6 @@ def fused_experts_with_all2all(
         expert_tokens = expert_tokens.to(torch.int64)
         group_list_type = 1
     else:
-        row_idx_len = num_tokens * top_k
-        row_idx = torch.arange(0,
-                               row_idx_len,
-                               dtype=torch.int32,
-                               device=topk_weights.device).view(
-                                   top_k, -1).permute(1, 0).contiguous()
         hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
             hidden_states,
             row_idx=row_idx,
@@ -627,6 +617,7 @@ def fused_experts(hidden_states: torch.Tensor,
                   w2_scale: torch.Tensor,
                   topk_weights: torch.Tensor,
                   topk_ids: torch.Tensor,
+                  row_idx: torch.Tensor,
                   top_k: int,
                   expert_map: torch.Tensor = None):
     original_shape = hidden_states.shape
@@ -677,12 +668,6 @@ def fused_experts(hidden_states: torch.Tensor,
         hidden_states = hidden_states[sorted_token_indices]
         group_list_type = 1
     else:
-        row_idx_len = num_tokens * top_k
-        row_idx = torch.arange(0,
-                               row_idx_len,
-                               dtype=torch.int32,
-                               device=topk_weights.device).view(
-                                   top_k, -1).permute(1, 0).contiguous()
         hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
             hidden_states,
             row_idx=row_idx,
@@ -903,7 +888,7 @@ def apply(
         assert router_logits.shape[
             1] == global_num_experts, "Number of global experts mismatch"
 
-        topk_weights, topk_ids = select_experts(
+        topk_weights, topk_ids, row_idx = select_experts(
             hidden_states=x,
             router_logits=router_logits,
             top_k=top_k,
@@ -973,6 +958,7 @@ def apply(
                 w2_scale=layer.w2_weight_scale,
                 topk_weights=topk_weights,
                 topk_ids=topk_ids,
+                row_idx=row_idx,
                 top_k=top_k,
                 expert_map=expert_map)
         else:
@@ -988,6 +974,7 @@ def apply(
                 w2_scale=layer.w2_weight_scale,
                 topk_weights=topk_weights,
                 topk_ids=topk_ids,
+                row_idx=row_idx,
                 top_k=top_k,
                 expert_map=expert_map,
                 ep_group=self.ep_group,
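
For reviewers, the shape convention behind the new `row_idx` plumbing: `select_experts` now builds the tensor once (via `return_row_idx`, or takes it directly from `npu_moe_gating_top_k_softmax`) and every `fused_experts*` variant receives it instead of recomputing it right before `npu_moe_init_routing`. Below is a minimal sketch of the layout, mirroring `return_row_idx` from this patch; the `build_row_idx` name is illustrative only, and the tensor is built on CPU so the snippet runs without `torch_npu`:

```python
import torch


def build_row_idx(num_tokens: int, top_k: int,
                  device: torch.device = torch.device("cpu")) -> torch.Tensor:
    # Token-major index over the num_tokens * top_k expanded rows:
    # entry [t, k] == t + k * num_tokens, i.e. the same layout the patch
    # passes to npu_moe_init_routing as its row_idx argument.
    row_idx_len = num_tokens * top_k
    return (torch.arange(0,
                         row_idx_len,
                         dtype=torch.int32,
                         device=device).view(top_k,
                                             -1).permute(1, 0).contiguous())


print(build_row_idx(num_tokens=4, top_k=2))
# tensor([[0, 4],
#         [1, 5],
#         [2, 6],
#         [3, 7]], dtype=torch.int32)
```

Callers that only need the routing weights can simply ignore the extra return value (`topk_weights, topk_ids, _ = select_experts(...)`), as the updated unit tests do.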