|
20 | 20 | import torch.nn as nn
|
21 | 21 | import torch_npu
|
22 | 22 | from pytest_mock import MockerFixture
|
| 23 | +from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase |
23 | 24 |
|
24 | 25 | from vllm_ascend.ascend_forward_context import _get_fused_moe_state
|
25 | 26 | from vllm_ascend.ops.fused_moe import (AscendFusedMoE,
|
@@ -59,6 +60,7 @@ def mock_dist_env(mocker: MockerFixture):
|
59 | 60 | patch('vllm_ascend.ops.fused_moe.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
|
60 | 61 | patch('vllm.distributed.parallel_state.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
|
61 | 62 | patch('vllm_ascend.ops.fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
|
| 63 | + patch('vllm.model_executor.layers.fused_moe.layer.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \ |
62 | 64 | patch('torch.distributed.all_gather', return_value=MagicMock(return_value=torch.randn(10,32))), \
|
63 | 65 | patch('torch.distributed.all_to_all_single', return_value=torch.randn(8, 32)), \
|
64 | 66 | patch('vllm_ascend.ops.fused_moe.tensor_model_parallel_all_reduce',
|
@@ -180,6 +182,18 @@ def __init__(self, shared_experts, num_tokens):
|
180 | 182 | self.apply = MagicMock(return_value=(torch.randn(num_tokens, 32)))
|
181 | 183 |
|
182 | 184 |
|
class MockFusedMoEMethod(FusedMoEMethodBase):
    """Minimal concrete stand-in for ``FusedMoEMethodBase`` used in tests.

    Provides no-op overrides of ``create_weights`` and ``apply`` (presumably
    the base class's abstract methods — defined in vLLM, outside this file)
    so an instance can be returned from a mocked quant config's
    ``get_quant_method`` without allocating real weights or doing compute.
    """

    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
        # No-op: the tests only need the method to exist; no weights are
        # registered on `layer`.
        pass

    def apply(self, hidden_states: torch.Tensor,
              expert_weights: torch.Tensor) -> torch.Tensor:
        # No-op: implicitly returns None despite the Tensor annotation —
        # NOTE(review): assumes callers in these tests never consume the
        # result; confirm if this mock is reused elsewhere.
        pass
183 | 197 | class TestAscendFusedMoe:
|
184 | 198 |
|
185 | 199 | def test_init_no_quant(self, mock_dist_env, default_moe_config):
|
@@ -213,7 +227,7 @@ def test_init_no_quant(self, mock_dist_env, default_moe_config):
|
213 | 227 |
|
214 | 228 | def test_init_with_quant(self, mock_dist_env, default_moe_config):
|
215 | 229 | mock_quant_config = MagicMock()
|
216 |
| - mock_quant_method = MagicMock() |
| 230 | + mock_quant_method = MockFusedMoEMethod() |
217 | 231 | mock_quant_config.get_quant_method.return_value = mock_quant_method
|
218 | 232 |
|
219 | 233 | moe = AscendFusedMoE(**default_moe_config,
|
|
0 commit comments