
Commit ad10837

[CI][Quickfix] Fix AscendFusedMoE init error (#2268)
### What this PR does / why we need it?

Fix the AscendFusedMoE init error. Use `super().__init__()` instead of `super(FusedMoE, self).__init__()` so that the member variables set in the base class can be accessed by the child class.

### Does this PR introduce _any_ user-facing change?

N/A

### How was this patch tested?

CI passed with new and existing tests.

- vLLM version: v0.10.0
- vLLM main: vllm-project/vllm@766bc81

---------

Signed-off-by: MengqingCao <[email protected]>
Parent: dceef08 · Commit: ad10837
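The fix hinges on how Python's `super()` interacts with the method resolution order: `super(FusedMoE, self)` starts the lookup *after* `FusedMoE`, so `FusedMoE.__init__()` itself never runs and the members it would set are missing from the instance, whereas the zero-argument `super()` starts at the immediate parent. A minimal, self-contained sketch of the difference (the `Base`/`FusedMoELike`/`AscendLike` classes are hypothetical stand-ins, not the real vLLM hierarchy):

```python
# Hypothetical three-level hierarchy mirroring Base -> FusedMoE -> AscendFusedMoE.
class Base:
    def __init__(self):
        self.base_ready = True


class FusedMoELike(Base):
    def __init__(self):
        super().__init__()
        self.weights = "allocated"  # member the child class relies on


class AscendLike(FusedMoELike):
    def __init__(self, skip_parent: bool):
        if skip_parent:
            # Starts the MRO lookup *after* FusedMoELike, so its __init__
            # never runs and self.weights is never set (the old behavior).
            super(FusedMoELike, self).__init__()
        else:
            # Starts the lookup at FusedMoELike, so self.weights exists
            # (the behavior after this fix).
            super().__init__()


broken = AscendLike(skip_parent=True)
fixed = AscendLike(skip_parent=False)
print(hasattr(broken, "weights"), hasattr(fixed, "weights"))  # False True
```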

File tree

2 files changed: +36 −3 lines

tests/ut/ops/test_fused_ops.py (15 additions, 1 deletion)

@@ -20,6 +20,7 @@
 import torch.nn as nn
 import torch_npu
 from pytest_mock import MockerFixture
+from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase
 
 from vllm_ascend.ascend_forward_context import _get_fused_moe_state
 from vllm_ascend.ops.fused_moe import (AscendFusedMoE,
@@ -59,6 +60,7 @@ def mock_dist_env(mocker: MockerFixture):
         patch('vllm_ascend.ops.fused_moe.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
         patch('vllm.distributed.parallel_state.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
         patch('vllm_ascend.ops.fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
+        patch('vllm.model_executor.layers.fused_moe.layer.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
         patch('torch.distributed.all_gather', return_value=MagicMock(return_value=torch.randn(10,32))), \
         patch('torch.distributed.all_to_all_single', return_value=torch.randn(8, 32)), \
         patch('vllm_ascend.ops.fused_moe.tensor_model_parallel_all_reduce',
@@ -180,6 +182,18 @@ def __init__(self, shared_experts, num_tokens):
         self.apply = MagicMock(return_value=(torch.randn(num_tokens, 32)))
 
 
+class MockFusedMoEMethod(FusedMoEMethodBase):
+
+    def create_weights(self, layer: torch.nn.Module, num_experts: int,
+                       hidden_size: int, intermediate_size_per_partition: int,
+                       params_dtype: torch.dtype, **extra_weight_attrs):
+        pass
+
+    def apply(self, hidden_states: torch.Tensor,
+              expert_weights: torch.Tensor) -> torch.Tensor:
+        pass
+
+
 class TestAscendFusedMoe:
 
     def test_init_no_quant(self, mock_dist_env, default_moe_config):
@@ -213,7 +227,7 @@ def test_init_no_quant(self, mock_dist_env, default_moe_config):
 
     def test_init_with_quant(self, mock_dist_env, default_moe_config):
         mock_quant_config = MagicMock()
-        mock_quant_method = MagicMock()
+        mock_quant_method = MockFusedMoEMethod()
         mock_quant_config.get_quant_method.return_value = mock_quant_method
 
         moe = AscendFusedMoE(**default_moe_config,
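Since the base `FusedMoE.__init__()` now actually runs, the test replaces the bare `MagicMock()` quant method with `MockFusedMoEMethod`, a concrete `FusedMoEMethodBase` subclass; presumably the base initializer validates the quant method's type, which a plain mock would not satisfy. A self-contained sketch of why a concrete stub passes an `isinstance` check that a bare `MagicMock` fails (the `QuantMethodBase`/`StubQuantMethod` names are hypothetical, not vLLM's actual classes):

```python
from abc import ABC, abstractmethod
from unittest.mock import MagicMock


# Hypothetical abstract base, standing in for FusedMoEMethodBase.
class QuantMethodBase(ABC):
    @abstractmethod
    def create_weights(self, layer, num_experts, hidden_size):
        ...

    @abstractmethod
    def apply(self, hidden_states):
        ...


# Concrete no-op stub, analogous to MockFusedMoEMethod in the test.
class StubQuantMethod(QuantMethodBase):
    def create_weights(self, layer, num_experts, hidden_size):
        pass

    def apply(self, hidden_states):
        pass


print(isinstance(MagicMock(), QuantMethodBase))        # False
print(isinstance(StubQuantMethod(), QuantMethodBase))  # True
# MagicMock(spec=QuantMethodBase) would also pass isinstance, but a tiny
# concrete stub keeps the required interface explicit.
```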

vllm_ascend/ops/fused_moe.py (21 additions, 2 deletions)

@@ -1181,8 +1181,27 @@ def __init__(
     ):
         # TODO: This could not initialize FusedMoE baseclass,
         # fixme and make __init__() of AscendFusedMoE more clear
-        super(FusedMoE, self).__init__()
-
+        super().__init__(
+            num_experts=num_experts,
+            top_k=top_k,
+            hidden_size=hidden_size,
+            intermediate_size=intermediate_size,
+            params_dtype=params_dtype,
+            reduce_results=reduce_results,
+            renormalize=renormalize,
+            use_grouped_topk=use_grouped_topk,
+            num_expert_group=num_expert_group,
+            topk_group=topk_group,
+            quant_config=quant_config,
+            tp_size=tp_size,
+            ep_size=ep_size,
+            dp_size=dp_size,
+            prefix=prefix,
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias,
+            activation=activation,
+        )
         AscendFusedMoE.moe_counter += 1
         self.moe_instance_id = AscendFusedMoE.moe_counter
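The replacement call forwards every constructor argument to the `FusedMoE` base class by keyword, so the attributes the base initializer sets become available on the `AscendFusedMoE` instance. A minimal sketch of that forwarding pattern (the `Parent`/`Child` classes and fields are hypothetical, not the real signatures):

```python
# Hypothetical parent/child pair showing explicit keyword forwarding.
class Parent:
    def __init__(self, num_experts: int, top_k: int, hidden_size: int):
        self.num_experts = num_experts
        self.top_k = top_k
        self.hidden_size = hidden_size


class Child(Parent):
    def __init__(self, num_experts: int, top_k: int, hidden_size: int):
        super().__init__(
            num_experts=num_experts,
            top_k=top_k,
            hidden_size=hidden_size,
        )
        # Members populated by the parent initializer are usable here.
        self.experts_per_rank = self.num_experts // 2


child = Child(num_experts=8, top_k=2, hidden_size=64)
print(child.top_k, child.experts_per_rank)  # 2 4
```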
