|
28 | 28 | AscendUnquantizedFusedMoEMethod)
|
29 | 29 | from vllm_ascend.ops.moe.experts_selector import select_experts
|
30 | 30 | from vllm_ascend.ops.moe.moe_mlp import cumsum_group_list, unified_apply_mlp
|
31 |
| -from vllm_ascend.utils import AscendSocVersion, adapt_patch |
| 31 | +from vllm_ascend.utils import AscendSocVersion, adapt_patch, vllm_version_is |
32 | 32 |
|
33 | 33 | adapt_patch(True)
|
34 | 34 |
|
@@ -93,14 +93,18 @@ def mock_finalize(hidden_states, **kwargs):
|
93 | 93 |
|
94 | 94 | mock_moe_comm_method.finalize.side_effect = mock_finalize
|
95 | 95 |
|
96 |
| - mock_forward_context_obj = MagicMock( |
97 |
| - moe_comm_method=mock_moe_comm_method, |
98 |
| - moe_comm_type=MoECommType.MC2, |
99 |
| - max_tokens_across_dp=10, |
100 |
| - dp_metadata=MagicMock(cu_tokens_across_dp_cpu=[5, 10]), |
101 |
| - mc2_mask=torch.zeros(16, dtype=torch.bool), |
102 |
| - padded_num_tokens=16, |
103 |
| - with_quant=False) |
| 96 | + if vllm_version_is("0.10.2"): |
| 97 | + dp_metadata = MagicMock(cu_tokens_across_dp_cpu=[5, 10]) |
| 98 | + else: |
| 99 | + dp_metadata = MagicMock(num_tokens_across_dp_cpu=[5, 5]) |
| 100 | + mock_forward_context_obj = MagicMock(moe_comm_method=mock_moe_comm_method, |
| 101 | + moe_comm_type=MoECommType.MC2, |
| 102 | + max_tokens_across_dp=10, |
| 103 | + dp_metadata=dp_metadata, |
| 104 | + mc2_mask=torch.zeros( |
| 105 | + 16, dtype=torch.bool), |
| 106 | + padded_num_tokens=16, |
| 107 | + with_quant=False) |
104 | 108 |
|
105 | 109 | with patch('torch.distributed.get_rank', return_value=0), \
|
106 | 110 | patch('torch.distributed.get_world_size', return_value=4), \
|
|
0 commit comments