diff --git a/tests/e2e/singlecard/ops/test_fused_moe.py b/tests/e2e/singlecard/ops/test_fused_moe.py
index c6da287bb6..4fb10aaf05 100644
--- a/tests/e2e/singlecard/ops/test_fused_moe.py
+++ b/tests/e2e/singlecard/ops/test_fused_moe.py
@@ -118,12 +118,6 @@ def test_token_dispatcher_with_all_gather(
     score = torch.softmax(score, dim=-1, dtype=dtype)
     topk_weights, topk_ids = torch.topk(score, topk)
     topk_ids = topk_ids.to(torch.int32)
-    row_idx = (torch.arange(
-        0,
-        m * topk,
-        device=device,
-        dtype=torch.int32,
-    ).view(topk, -1).permute(1, 0).contiguous())
 
     dispatcher_kwargs = {
         "num_experts": e,
@@ -137,7 +131,6 @@ def test_token_dispatcher_with_all_gather(
         hidden_states=a,
         topk_weights=topk_weights,
         topk_ids=topk_ids,
-        row_idx=row_idx,
         expert_map=expert_map,
         apply_router_weight_on_input=apply_router_weight_on_input)
 
@@ -201,12 +194,6 @@ def test_token_dispatcher_with_all_gather_quant(
     score = torch.softmax(score, dim=-1, dtype=dtype)
     topk_weights, topk_ids = torch.topk(score, topk)
     topk_ids = topk_ids.to(torch.int32)
-    row_idx = (torch.arange(
-        0,
-        m * topk,
-        device=device,
-        dtype=torch.int32,
-    ).view(topk, -1).permute(1, 0).contiguous())
 
     dispatcher_kwargs = {
         "num_experts": e,
@@ -220,7 +207,6 @@ def test_token_dispatcher_with_all_gather_quant(
         hidden_states=a,
         topk_weights=topk_weights,
         topk_ids=topk_ids,
-        row_idx=row_idx,
         expert_map=expert_map,
         apply_router_weight_on_input=apply_router_weight_on_input,
         with_quant=True)
@@ -295,7 +281,7 @@ def test_select_experts(
     mock_native_grouped_topk.side_effect = lambda x, num_groups, k: torch.randn_like(
         x)
 
-    topk_weights, topk_ids, row_idx = select_experts(
+    topk_weights, topk_ids = select_experts(
         hidden_states=hidden_states,
         router_logits=router_logits,
         top_k=topk,
@@ -316,7 +302,6 @@ def test_select_experts(
     assert topk_weights.shape == (m, topk)
     assert topk_ids.shape == (m, topk)
    assert topk_ids.dtype == torch.int32
-    assert row_idx.shape == (m, topk)
 
     gc.collect()
     torch.npu.empty_cache()
diff --git a/tests/ut/ops/test_fused_ops.py b/tests/ut/ops/test_fused_ops.py
index a5bdfe225d..2f5800f182 100644
--- a/tests/ut/ops/test_fused_ops.py
+++ b/tests/ut/ops/test_fused_ops.py
@@ -454,7 +454,7 @@ def test_select_experts(self, mock_dist_env, mock_moe_env,
         x = torch.randn(8, 2)
         router_logits = torch.randn(8, 2)
 
-        topk_weights, topk_ids, _ = select_experts(
+        topk_weights, topk_ids = select_experts(
             hidden_states=x,
             router_logits=router_logits,
             top_k=2,
diff --git a/tests/ut/ops/test_moe_comm_method.py b/tests/ut/ops/test_moe_comm_method.py
index 97aea93c61..8fc2bae1c9 100644
--- a/tests/ut/ops/test_moe_comm_method.py
+++ b/tests/ut/ops/test_moe_comm_method.py
@@ -204,7 +204,6 @@ def test_fused_experts_method(self, mock_unified_apply_mlp,
         topk_weights = torch.tensor([[0.5, 0.5], [0.3, 0.7], [0.8, 0.2],
                                      [0.6, 0.4]])
         topk_ids = torch.tensor([[0, 1], [1, 2], [2, 0], [1, 1]])
-        row_idx = torch.arange(4)
 
         # Make sure tensors are contiguous and have correct strides
         hidden_states = hidden_states.contiguous()
@@ -216,7 +215,6 @@ def test_fused_experts_method(self, mock_unified_apply_mlp,
                                w2=w2,
                                topk_weights=topk_weights,
                                topk_ids=topk_ids,
-                               row_idx=row_idx,
                                activation="silu")
 
         # Verify result shape
diff --git a/tests/ut/ops/test_token_dispatcher.py b/tests/ut/ops/test_token_dispatcher.py
index cc2d30774d..bc6c8932fb 100644
--- a/tests/ut/ops/test_token_dispatcher.py
+++ b/tests/ut/ops/test_token_dispatcher.py
@@ -58,7 +58,6 @@ def setUp(self):
         kwargs = {"with_quant": False, "top_k": 8, "num_experts": 128}
         self.dispatcher = TokenDispatcherWithMC2(**kwargs)
-        self.row_idx = torch.arange(10, dtype=torch.int32)
 
     def tearDown(self):
         self.mc2_group_patch.stop()
@@ -95,7 +94,7 @@ def test_token_permutation_dispatch(self):
                    return_value=(torch.randn(10, 128), ) * 5) as mock_dispatch:
             output = self.dispatcher.token_dispatch(hidden_states,
                                                     topk_weights, topk_ids,
-                                                    self.row_idx, expert_map)
+                                                    expert_map)
             mock_dispatch.assert_called_once()
             self.assertEqual(output["group_list_type"],
                              0)  # group_list_type == 0
@@ -116,7 +115,6 @@ def test_token_dispatch_with_shared_experts_and_quant(self):
         self.dispatcher.token_dispatch(self.hidden_states, self.topk_weights,
                                        torch.randint(0, 8, (10, 1)),
-                                       self.row_idx,
                                        torch.tensor(
                                            [0, 1, 2, 3, 4, 5, 6, 7]),
                                        shared_experts=self.shared_experts)
@@ -180,7 +178,6 @@ def setUp(self):
             torch.tensor([0, 1, 2, 3, 4, 5]),  # expanded_row_idx
             torch.tensor([0, 1, 0, 1, 0, 1]),  # expanded_expert_idx
             torch.tensor([0, 1, 0, 1, 0, 1]))
-        self.row_idx = torch.arange(10, dtype=torch.int32)
         self.patcher_npu_moe_token_unpermute = patch(
             'torch_npu.npu_moe_token_unpermute')
         self.mock_npu_moe_token_unpermute = self.patcher_npu_moe_token_unpermute.start(
         )
@@ -197,7 +194,7 @@ def test_token_dispatch_without_expert_map(self):
         topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
 
         results = self.dispatcher.token_dispatch(hidden_states, topk_weights,
-                                                 topk_ids, self.row_idx, None)
+                                                 topk_ids, None)
 
         # Verify npu_moe_init_routing is called
         self.mock_npu_moe_init_routing_v2.assert_called_once()
@@ -212,7 +209,7 @@ def test_token_dispatch_with_expert_map(self):
         topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
 
         results = self.dispatcher.token_dispatch(hidden_states, topk_weights,
-                                                 topk_ids, self.row_idx, None)
+                                                 topk_ids, None)
 
         # Verify npu_moe_init_routing is called
         self.mock_npu_moe_init_routing_v2.assert_called_once()
@@ -236,7 +233,7 @@ def test_token_dispatch_without_quant(self):
 
         results = self.dispatcher_quant.token_dispatch(hidden_states,
                                                        topk_weights, topk_ids,
-                                                       self.row_idx, None)
+                                                       None)
 
         self.assertEqual(results["group_list_type"], 1)
@@ -257,7 +254,6 @@ def test_token_dispatch_with_quant(self):
         results = self.dispatcher_quant.token_dispatch(hidden_states,
                                                        topk_weights, topk_ids,
-                                                       self.row_idx,
                                                        None,
                                                        with_quant=True)
@@ -399,7 +395,6 @@ def setUp(self):
             num_experts=4,
             num_local_experts=2,
             with_quant=False)
-        self.row_idx = torch.arange(10, dtype=torch.int32)
 
     def test_token_dispatch(self):
         hidden_states = torch.randn(8, 16)
@@ -414,7 +409,6 @@ def test_token_dispatch(self):
         result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
                                                 topk_weights=topk_weights,
                                                 topk_ids=topk_ids,
-                                                row_idx=self.row_idx,
                                                 expert_map=expert_map)
 
         self.assertIsNotNone(result["hidden_states"])
@@ -461,7 +455,6 @@ def test_token_dispatch_with_quant(self):
         result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
                                                 topk_weights=topk_weights,
                                                 topk_ids=topk_ids,
-                                                row_idx=self.row_idx,
                                                 expert_map=expert_map,
                                                 with_quant=True)
@@ -490,7 +483,6 @@ def test_token_dispatch_with_quant_no_active_tokens(self):
         result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
                                                 topk_weights=topk_weights,
                                                 topk_ids=topk_ids,
-                                                row_idx=self.row_idx,
                                                 expert_map=expert_map,
                                                 with_quant=True)
@@ -513,7 +505,6 @@ def test_token_dispatch_with_log2phy(self):
         result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
                                                 topk_weights=topk_weights,
                                                 topk_ids=topk_ids,
-                                                row_idx=self.row_idx,
                                                 expert_map=expert_map,
                                                 log2phy=log2phy)
diff --git a/tests/ut/quantization/test_w8a8.py b/tests/ut/quantization/test_w8a8.py
index 3f2557bebe..5a4688ce78 100644
--- a/tests/ut/quantization/test_w8a8.py
+++ b/tests/ut/quantization/test_w8a8.py
@@ -715,8 +715,16 @@ def setUp(self):
 
         self.hidden_states = torch.randn(self.num_tokens, self.hidden_size)
         self.router_logits = torch.randn(self.num_tokens, self.num_experts)
-
-    @patch('torch_npu.npu_moe_gating_top_k_softmax')
+        """Mock custom routing"""
+        self.mock_custom_routing = MagicMock()
+        self.mock_custom_routing.return_value = (torch.ones(
+            self.num_tokens, self.top_k),
+                                                 torch.zeros(
+                                                     self.num_tokens,
+                                                     self.top_k,
+                                                     dtype=torch.int32))
+
+    @patch('torch_npu.npu_moe_gating_top_k')
     def test_softmax_scoring(self, mock_topk):
         """Test softmax scoring function"""
         mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k),
@@ -730,12 +738,12 @@ def test_softmax_scoring(self, mock_topk):
                                       -1).permute(1, 0).contiguous())
 
-        weights, ids, _ = select_experts(hidden_states=self.hidden_states,
-                                         router_logits=self.router_logits,
-                                         top_k=self.top_k,
-                                         use_grouped_topk=False,
-                                         renormalize=False,
-                                         scoring_func="softmax")
+        weights, ids = select_experts(hidden_states=self.hidden_states,
+                                      router_logits=self.router_logits,
+                                      top_k=self.top_k,
+                                      use_grouped_topk=False,
+                                      renormalize=False,
+                                      scoring_func="softmax")
 
         self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
         self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
@@ -743,12 +751,14 @@ def test_softmax_scoring(self, mock_topk):
 
     def test_sigmoid_scoring(self):
         """Test sigmoid scoring function"""
-        weights, ids, _ = select_experts(hidden_states=self.hidden_states,
-                                         router_logits=self.router_logits,
-                                         top_k=self.top_k,
-                                         use_grouped_topk=False,
-                                         renormalize=False,
-                                         scoring_func="sigmoid")
+        weights, ids = select_experts(
+            hidden_states=self.hidden_states,
+            router_logits=self.router_logits,
+            top_k=self.top_k,
+            use_grouped_topk=False,
+            renormalize=False,
+            scoring_func="sigmoid",
+            custom_routing_function=self.mock_custom_routing)
 
         self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
         self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
@@ -761,7 +771,8 @@ def test_invalid_scoring_func(self):
                            top_k=self.top_k,
                            use_grouped_topk=False,
                            renormalize=False,
-                           scoring_func="invalid_func")
+                           scoring_func="invalid_func",
+                           custom_routing_function=self.mock_custom_routing)
 
     @patch('torch.topk')
     def test_grouped_topk(self, mock_topk):
@@ -771,13 +782,15 @@ def test_grouped_topk(self, mock_topk):
                                               self.top_k,
                                               dtype=torch.long))
 
-        weights, ids, _ = select_experts(hidden_states=self.hidden_states,
-                                         router_logits=self.router_logits,
-                                         top_k=self.top_k,
-                                         use_grouped_topk=True,
-                                         renormalize=False,
-                                         topk_group=4,
-                                         num_expert_group=2)
+        weights, ids = select_experts(
+            hidden_states=self.hidden_states,
+            router_logits=self.router_logits,
+            top_k=self.top_k,
+            use_grouped_topk=True,
+            renormalize=False,
+            topk_group=4,
+            num_expert_group=2,
+            custom_routing_function=self.mock_custom_routing)
 
         mock_topk.assert_called()
         self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
@@ -791,7 +804,7 @@ def test_grouped_topk_with_correction_bias(self, mock_grouped_topk):
                                               self.num_experts)
         e_score_correction_bias = torch.randn(self.num_experts)
 
-        weights, ids, _ = select_experts(
+        weights, ids = select_experts(
             hidden_states=self.hidden_states,
             router_logits=self.router_logits,
             top_k=self.top_k,
@@ -799,7 +812,8 @@ def test_grouped_topk_with_correction_bias(self, mock_grouped_topk):
             renormalize=False,
             topk_group=4,
             num_expert_group=2,
-            e_score_correction_bias=e_score_correction_bias)
+            e_score_correction_bias=e_score_correction_bias,
+            custom_routing_function=self.mock_custom_routing)
 
         mock_grouped_topk.assert_called_once()
         self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
@@ -814,7 +828,7 @@ def test_custom_routing_function(self):
                                                self.top_k,
                                                dtype=torch.int32))
 
-        weights, ids, _ = select_experts(
+        weights, ids = select_experts(
             hidden_states=self.hidden_states,
             router_logits=self.router_logits,
             top_k=self.top_k,
@@ -827,7 +841,7 @@ def test_custom_routing_function(self):
         self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
         self.assertEqual(ids.dtype, torch.int32)
 
-    @patch('torch_npu.npu_moe_gating_top_k_softmax')
+    @patch('torch_npu.npu_moe_gating_top_k')
     def test_renormalize(self, mock_topk):
         """Test renormalization"""
         mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k),
@@ -841,7 +855,7 @@ def test_renormalize(self, mock_topk):
                                       -1).permute(1, 0).contiguous())
 
-        weights, ids, _ = select_experts(
+        weights, ids = select_experts(
             hidden_states=self.hidden_states,
             router_logits=self.router_logits,
             top_k=self.top_k,
@@ -853,7 +867,7 @@ def test_renormalize(self, mock_topk):
         sums = weights.sum(dim=-1)
         self.assertTrue(torch.allclose(sums, torch.ones_like(sums)))
 
-    @patch('torch_npu.npu_moe_gating_top_k_softmax')
+    @patch('torch_npu.npu_moe_gating_top_k')
     def test_output_dtypes(self, mock_topk):
         """Test output dtypes"""
         mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k),
@@ -867,7 +881,7 @@ def test_output_dtypes(self, mock_topk):
                                       -1).permute(1, 0).contiguous())
 
-        weights, ids, _ = select_experts(
+        weights, ids = select_experts(
             hidden_states=self.hidden_states,
             router_logits=self.router_logits,
             top_k=self.top_k,
@@ -876,7 +890,6 @@ def test_output_dtypes(self, mock_topk):
         )
 
         self.assertEqual(weights.dtype, self.hidden_states.dtype)
-        self.assertEqual(ids.dtype, torch.int32)
 
 
 class TestNativeGroupedTopkPartialMock(TestBase):
diff --git a/vllm_ascend/ops/common_fused_moe.py b/vllm_ascend/ops/common_fused_moe.py
index ac22b69bcc..e9945774e1 100644
--- a/vllm_ascend/ops/common_fused_moe.py
+++ b/vllm_ascend/ops/common_fused_moe.py
@@ -81,7 +81,7 @@ def forward_oot(
         logical_to_physical_map: Optional[torch.Tensor] = None,
         logical_replica_count: Optional[torch.Tensor] = None) -> torch.Tensor:
 
-        topk_weights, topk_ids, row_idx = select_experts(
+        topk_weights, topk_ids = select_experts(
             hidden_states=x,
             router_logits=router_logits,
             top_k=top_k,
@@ -101,7 +101,6 @@ def forward_oot(
             w2=layer.w2_weight,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
-            row_idx=row_idx,
             global_num_experts=global_num_experts,
             expert_map=expert_map)
diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py
index 97489f9ac3..6dd820bc14 100644
--- a/vllm_ascend/ops/fused_moe.py
+++ b/vllm_ascend/ops/fused_moe.py
@@ -105,7 +105,7 @@ def apply(
         **kwargs,
     ) -> torch.Tensor:
 
-        topk_weights, topk_ids, row_idx = select_experts(
+        topk_weights, topk_ids = select_experts(
             hidden_states=x,
             router_logits=router_logits,
             top_k=top_k,
@@ -132,7 +132,6 @@ def apply(
             w2=layer.w2_weight,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
-            row_idx=row_idx,
             global_num_experts=global_num_experts,
             expert_map=expert_map,
             shared_experts=shared_experts,
diff --git a/vllm_ascend/ops/moe/experts_selector.py b/vllm_ascend/ops/moe/experts_selector.py
index eace1644fc..22f6da1ad8 100644
--- a/vllm_ascend/ops/moe/experts_selector.py
+++ b/vllm_ascend/ops/moe/experts_selector.py
@@ -20,17 +20,6 @@
 import torch_npu
 
 
-def return_row_idx(hidden_states, top_k):
-    num_tokens = hidden_states.shape[0]
-    row_idx_len = num_tokens * top_k
-    row_idx = (torch.arange(0,
-                            row_idx_len,
-                            dtype=torch.int32,
-                            device=hidden_states.device).view(
-                                top_k, -1).permute(1, 0).contiguous())
-    return row_idx
-
-
 def select_experts(hidden_states: torch.Tensor,
                    router_logits: torch.Tensor,
                    top_k: int,
@@ -66,21 +55,21 @@ def select_experts(hidden_states: torch.Tensor,
         topk_ids: selected expert IDs of shape (num_tokens, top_k).
     """
 
-    topk_weights, topk_ids, row_idx = _select_experts_with_fusion_ops(
-        hidden_states=hidden_states,
-        router_logits=router_logits,
-        top_k=top_k,
-        use_grouped_topk=use_grouped_topk,
-        topk_group=topk_group,
-        renormalize=renormalize,
-        e_score_correction_bias=e_score_correction_bias,
-        num_expert_group=num_expert_group,
-        custom_routing_function=custom_routing_function,
-        scoring_func=scoring_func,
-        routed_scaling_factor=routed_scaling_factor,
-        global_num_experts=global_num_experts)
-
-    if topk_weights is None:
+    topk_weights, topk_ids = None, None
+    if custom_routing_function is None:
+        topk_weights, topk_ids = _select_experts_with_fusion_ops(
+            hidden_states=hidden_states,
+            router_logits=router_logits,
+            top_k=top_k,
+            use_grouped_topk=use_grouped_topk,
+            topk_group=topk_group,
+            renormalize=renormalize,
+            e_score_correction_bias=e_score_correction_bias,
+            num_expert_group=num_expert_group,
+            scoring_func=scoring_func,
+            routed_scaling_factor=routed_scaling_factor,
+            global_num_experts=global_num_experts)
+    else:
         topk_weights, topk_ids = _native_select_experts(
             hidden_states=hidden_states,
             router_logits=router_logits,
@@ -94,9 +83,7 @@ def select_experts(hidden_states: torch.Tensor,
             e_score_correction_bias=e_score_correction_bias,
             global_num_experts=global_num_experts,
         )
-    if row_idx is None:
-        row_idx = return_row_idx(hidden_states, top_k)
-    return topk_weights, topk_ids, row_idx
+    return topk_weights, topk_ids
 
 
 def _native_grouped_topk(
@@ -177,37 +164,37 @@ def _select_experts_with_fusion_ops(
         e_score_correction_bias: Optional[torch.Tensor],
         topk_group: Optional[int],
         num_expert_group: Optional[int],
-        custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
         routed_scaling_factor=1.0,
         global_num_experts: int = -1):
 
-    topk_weights, topk_ids, row_idx = None, None, None
-    # NOTE: now npu_moe_gating_top_k can only support 'group_count=256' pattern
-    is_deepseek_v3_r1 = global_num_experts == 256
-    if is_deepseek_v3_r1:
-        topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
-            router_logits,
-            k=top_k,  # topk currently 8
-            bias=e_score_correction_bias,
-            k_group=topk_group,  # fix: 4
-            group_count=num_expert_group,  # fix 8
-            group_select_mode=
-            1,  # 0: the maximum in the group; 1: topk2.sum(fix)
-            renorm=0,  # 0: softmax->topk(fix); 1: topk->softmax
-            norm_type=1,  # 0: softmax; 1: sigmoid(fix)
-            # out_flag=False, # todo new api; should the third output be output
-            # y2_flag=False, # old api; should the third output be output
-            routed_scaling_factor=1,
-            eps=float(1e-20))
-        row_idx = return_row_idx(hidden_states, top_k)
-    if not use_grouped_topk and custom_routing_function is None and scoring_func == "softmax":
-        topk_weights, topk_ids, row_idx = torch_npu.npu_moe_gating_top_k_softmax(
-            x=router_logits, finished=None, k=top_k)
-        topk_ids = topk_ids.to(torch.int32)
+    if scoring_func == "softmax":
+        norm_type = 0
+        topk_group = 1
+        num_expert_group = 1
+    else:
+        norm_type = 1
+        if e_score_correction_bias is not None and \
+            e_score_correction_bias.dtype != router_logits.dtype:
+            e_score_correction_bias = e_score_correction_bias.to(
+                router_logits.dtype)
+    topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
+        router_logits,
+        k=top_k,
+        bias=e_score_correction_bias,
+        k_group=topk_group,
+        group_count=num_expert_group,
+        group_select_mode=1,  # 0: the maximum in the group; 1: topk2.sum(fix)
+        renorm=0,  # 0: softmax->topk(fix); 1: topk->softmax
+        norm_type=norm_type,  # 0: softmax; 1: sigmoid
+        # out_flag=False, # todo new api; should the third output be output
+        # y2_flag=False, # old api; should the third output be output
+        routed_scaling_factor=1,
+        eps=float(1e-20))
+
+    if scoring_func == "softmax":
         topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
 
-    return topk_weights, topk_ids, row_idx
+    return topk_weights, topk_ids
 
 
 def _native_select_experts(
diff --git a/vllm_ascend/ops/moe/moe_comm_method.py b/vllm_ascend/ops/moe/moe_comm_method.py
index 555189e0af..6c9152f988 100644
--- a/vllm_ascend/ops/moe/moe_comm_method.py
+++ b/vllm_ascend/ops/moe/moe_comm_method.py
@@ -89,7 +89,6 @@ def fused_experts(
         w2: torch.Tensor,
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
-        row_idx: torch.Tensor,
         activation: str = "silu",
         apply_router_weight_on_input: bool = False,
         use_int8_w8a8: bool = False,
@@ -123,7 +122,6 @@ def fused_experts(
             hidden_states=hidden_states,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
-            row_idx=row_idx,
             expert_map=expert_map,
             log2phy=log2phy,
             global_redundant_expert_num=global_redundant_expert_num,
diff --git a/vllm_ascend/ops/moe/token_dispatcher.py b/vllm_ascend/ops/moe/token_dispatcher.py
index b36cc44ff2..4bbcafd80f 100644
--- a/vllm_ascend/ops/moe/token_dispatcher.py
+++ b/vllm_ascend/ops/moe/token_dispatcher.py
@@ -61,7 +61,6 @@ def token_dispatch(self,
                        hidden_states: torch.Tensor,
                        topk_weights: torch.Tensor,
                        topk_ids: torch.Tensor,
-                       row_idx: torch.Tensor,
                        expert_map: Optional[torch.Tensor] = None,
                        log2phy: Optional[torch.Tensor] = None,
                        global_redundant_expert_num: int = 0,
@@ -161,7 +160,6 @@ def token_dispatch(self,
                        hidden_states: torch.Tensor,
                        topk_weights: torch.Tensor,
                        topk_ids: torch.Tensor,
-                       row_idx: torch.Tensor,
                        expert_map: Optional[torch.Tensor] = None,
                        log2phy: Optional[torch.Tensor] = None,
                        global_redundant_expert_num: int = 0,
@@ -318,7 +316,6 @@ def token_dispatch(self,
                        hidden_states: torch.Tensor,
                        topk_weights: torch.Tensor,
                        topk_ids: torch.Tensor,
-                       row_idx: torch.Tensor,
                        expert_map: Optional[torch.Tensor] = None,
                        log2phy: Optional[torch.Tensor] = None,
                        global_redundant_expert_num: int = 0,
@@ -410,7 +407,6 @@ def token_dispatch(self,
                        hidden_states: torch.Tensor,
                        topk_weights: torch.Tensor,
                        topk_ids: torch.Tensor,
-                       row_idx: torch.Tensor,
                        expert_map: Optional[torch.Tensor] = None,
                        log2phy: Optional[torch.Tensor] = None,
                        global_redundant_expert_num: int = 0,
@@ -511,7 +507,6 @@ def token_dispatch(self,
                        hidden_states: torch.Tensor,
                        topk_weights: torch.Tensor,
                        topk_ids: torch.Tensor,
-                       row_idx: torch.Tensor,
                        expert_map: Optional[torch.Tensor] = None,
                        log2phy: Optional[torch.Tensor] = None,
                        global_redundant_expert_num: int = 0,
diff --git a/vllm_ascend/quantization/w4a8_dynamic.py b/vllm_ascend/quantization/w4a8_dynamic.py
index b8bcc7831f..bd6c4a7a65 100644
--- a/vllm_ascend/quantization/w4a8_dynamic.py
+++ b/vllm_ascend/quantization/w4a8_dynamic.py
@@ -265,7 +265,7 @@ def apply(
             1] == global_num_experts, "Number of global experts mismatch"
 
         # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
-        topk_weights, topk_ids, row_idx = select_experts(
+        topk_weights, topk_ids = select_experts(
             hidden_states=x,
             router_logits=router_logits,
             top_k=top_k,
@@ -297,7 +296,6 @@ def apply(
             w2_scale_bias=layer.w2_scale_bias,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
-            row_idx=row_idx,
             use_int4_w4a8=True,
             expert_map=expert_map,
             log2phy=log2phy,
diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py
index ab4987f015..06ebea7b6f 100644
--- a/vllm_ascend/quantization/w8a8_dynamic.py
+++ b/vllm_ascend/quantization/w8a8_dynamic.py
@@ -204,7 +204,7 @@ def apply(
         assert router_logits.shape[
             1] == global_num_experts, "Number of global experts mismatch"
 
-        topk_weights, topk_ids, row_idx = select_experts(
+        topk_weights, topk_ids = select_experts(
             hidden_states=x,
             router_logits=router_logits,
             top_k=top_k,
@@ -225,7 +225,6 @@ def apply(
                 w2=layer.w2_weight,
                 topk_weights=topk_weights,
                 topk_ids=topk_ids,
-                row_idx=row_idx,
                 use_int8_w8a8=True,
                 w1_scale=layer.w13_weight_scale,
                 w2_scale=layer.w2_weight_scale,
@@ -249,7 +248,6 @@ def apply(
             w2_scale=layer.w2_weight_scale,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
-            row_idx=row_idx,
             use_int8_w8a8=True,
             expert_map=expert_map,
             log2phy=log2phy,
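Reviewer note, not part of the patch: a minimal sketch of the caller contract after this change, based only on the signatures shown in the hunks above. `x`, `router_logits`, `top_k`, `layer`, and `moe_comm_method` are stand-in names for whatever the calling quant method already has in scope.

# Sketch only: select_experts() now returns a 2-tuple, and row_idx is gone
# from the fused_experts()/token_dispatch() signatures.
from vllm_ascend.ops.moe.experts_selector import select_experts

topk_weights, topk_ids = select_experts(
    hidden_states=x,              # (num_tokens, hidden_size)
    router_logits=router_logits,  # (num_tokens, global_num_experts)
    top_k=top_k,
    use_grouped_topk=False,
    renormalize=True,
    scoring_func="softmax")       # fused path via torch_npu.npu_moe_gating_top_k

output = moe_comm_method.fused_experts(
    hidden_states=x,
    w1=layer.w13_weight,
    w2=layer.w2_weight,
    topk_weights=topk_weights,
    topk_ids=topk_ids,            # no row_idx argument anymore
    activation="silu")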