|
| 1 | +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. |
| 2 | +# Copyright 2023 The vLLM team. |
| 3 | +# |
| 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +# you may not use this file except in compliance with the License. |
| 6 | +# You may obtain a copy of the License at |
| 7 | +# |
| 8 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +# |
| 10 | +# Unless required by applicable law or agreed to in writing, software |
| 11 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +# See the License for the specific language governing permissions and |
| 14 | +# limitations under the License. |
| 15 | +# This file is a part of the vllm-ascend project. |
| 16 | + |
| 17 | +from types import SimpleNamespace |
| 18 | + |
| 19 | +import pytest |
| 20 | +import torch |
| 21 | +from transformers import PretrainedConfig |
| 22 | +from vllm import forward_context |
| 23 | + |
| 24 | +from vllm_ascend.distributed import moe_comm_method |
| 25 | +from vllm_ascend.distributed.moe_comm_method import (AllGatherCommImpl, |
| 26 | + NativeAllGatherCommImpl) |
| 27 | + |
| 28 | + |
| 29 | +@pytest.mark.parametrize("num_tokens", [16, 128]) |
| 30 | +@pytest.mark.parametrize("hidden_size", [64, 128]) |
| 31 | +@pytest.mark.parametrize("global_num_experts", [8, 16]) |
| 32 | +@pytest.mark.parametrize("top_k_num", [2, 4]) |
| 33 | +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) |
| 34 | +@pytest.mark.parametrize("num_local_experts", [4, 8]) |
| 35 | +@pytest.mark.parametrize("ep_rank", [0, 1]) |
| 36 | +def test_all_gather_comm_impl( |
| 37 | + num_tokens, |
| 38 | + hidden_size, |
| 39 | + global_num_experts, |
| 40 | + top_k_num, |
| 41 | + dtype, |
| 42 | + num_local_experts, |
| 43 | + ep_rank, |
| 44 | +): |
| 45 | + """ |
| 46 | + Tests the AllGatherCommImpl against the NativeAllGatherCommImpl. |
| 47 | +
|
| 48 | + This test compares the outputs of the NPU-optimized AllGatherCommImpl |
| 49 | + with a native PyTorch implementation (NativeAllGatherCommImpl) to ensure |
| 50 | + correctness across various configurations. |
| 51 | + """ |
| 52 | + if top_k_num > global_num_experts: |
| 53 | + pytest.skip("top_k_num cannot be greater than global_num_experts") |
| 54 | + if num_local_experts > global_num_experts: |
| 55 | + pytest.skip( |
| 56 | + "num_local_experts cannot be greater than global_num_experts") |
| 57 | + |
| 58 | + device = torch.device("npu") |
| 59 | + hf_config = PretrainedConfig( |
| 60 | + num_experts_per_tok=top_k_num, |
| 61 | + num_experts=global_num_experts, |
| 62 | + ) |
| 63 | + |
| 64 | + # Instantiate implementations |
| 65 | + native_impl = NativeAllGatherCommImpl(device, dtype, hf_config) |
| 66 | + |
| 67 | + all_gather_impl = AllGatherCommImpl(device, dtype, hf_config) |
| 68 | + |
| 69 | + # TODO: Find out if this is the correct way to mock the forward context and ep group |
| 70 | + # Mock get_forward_context to return an object with moe_comm_method |
| 71 | + forward_context._forward_context = SimpleNamespace( |
| 72 | + moe_comm_method=all_gather_impl) |
| 73 | + # Mock get_ep_group to return a fake group with the specified ep_rank |
| 74 | + fake_ep_group = SimpleNamespace(rank_in_group=ep_rank) |
| 75 | + moe_comm_method.get_ep_group = lambda: fake_ep_group |
| 76 | + |
| 77 | + # --- Input Data --- |
| 78 | + hidden_states = torch.randn(num_tokens, |
| 79 | + hidden_size, |
| 80 | + device=device, |
| 81 | + dtype=dtype) |
| 82 | + topk_ids = torch.randint(0, |
| 83 | + global_num_experts, (num_tokens, top_k_num), |
| 84 | + device=device, |
| 85 | + dtype=torch.int32) |
| 86 | + topk_weights = torch.rand(num_tokens, top_k_num, device=device).to(dtype) |
| 87 | + topk_weights = torch.nn.functional.softmax(topk_weights, dim=1) |
| 88 | + |
| 89 | + num_experts = global_num_experts |
| 90 | + |
| 91 | + expert_map = None |
| 92 | + if num_local_experts < global_num_experts: |
| 93 | + # Create a map where some experts are local and some are not |
| 94 | + expert_map = torch.full((global_num_experts, ), -1, device=device) |
| 95 | + expert_map[ep_rank * num_local_experts:(ep_rank + 1) * |
| 96 | + num_local_experts] = torch.arange(num_local_experts, |
| 97 | + device=device) |
| 98 | + num_experts = num_local_experts |
| 99 | + |
| 100 | + # --- Run Native Implementation (Golden Reference) --- |
| 101 | + native_hidden_states_out = hidden_states.clone() |
| 102 | + ( |
| 103 | + native_permuted_hidden, |
| 104 | + native_expert_tokens, |
| 105 | + _, |
| 106 | + ) = native_impl._pre_process(hidden_states, topk_ids, topk_weights, |
| 107 | + expert_map, num_experts) |
| 108 | + # Simulate MLP output |
| 109 | + native_mlp_output = torch.randn_like(native_permuted_hidden) |
| 110 | + native_impl._post_process(native_mlp_output, native_hidden_states_out) |
| 111 | + |
| 112 | + # --- Run AllGather Implementation --- |
| 113 | + all_gather_hidden_states_out = hidden_states.clone() |
| 114 | + ( |
| 115 | + all_gather_permuted_hidden, |
| 116 | + all_gather_expert_tokens, |
| 117 | + _, |
| 118 | + ) = torch.ops.vllm.moe_comm_pre_process(hidden_states, topk_ids, |
| 119 | + topk_weights, expert_map, |
| 120 | + num_experts) |
| 121 | + |
| 122 | + # Use the same simulated MLP output for a fair comparison |
| 123 | + all_gather_mlp_output = native_mlp_output.clone() |
| 124 | + |
| 125 | + torch.ops.vllm.moe_comm_post_process(all_gather_mlp_output, |
| 126 | + all_gather_hidden_states_out) |
| 127 | + |
| 128 | + # --- Assertions --- |
| 129 | + # Define tolerance based on dtype |
| 130 | + atol = 1e-3 if dtype == torch.float16 else 1e-2 |
| 131 | + rtol = 1e-3 if dtype == torch.float16 else 1e-2 |
| 132 | + |
| 133 | + # 1. Compare expert_tokens from pre_process |
| 134 | + assert torch.allclose(native_expert_tokens.to( |
| 135 | + all_gather_expert_tokens.device), |
| 136 | + all_gather_expert_tokens, |
| 137 | + atol=atol, |
| 138 | + rtol=rtol), "Expert tokens do not match." |
| 139 | + |
| 140 | + # 2. Compare permuted_hidden_states from pre_process |
| 141 | + num_valid_tokens = native_expert_tokens.sum() |
| 142 | + assert torch.allclose(native_permuted_hidden[:num_valid_tokens].to( |
| 143 | + all_gather_permuted_hidden.device), |
| 144 | + all_gather_permuted_hidden[:num_valid_tokens], |
| 145 | + atol=atol, |
| 146 | + rtol=rtol), "Permuted hidden states do not match." |
| 147 | + |
| 148 | + # 3. Compare final hidden_states from post_process |
| 149 | + assert torch.allclose(native_hidden_states_out.to( |
| 150 | + all_gather_hidden_states_out.device), |
| 151 | + all_gather_hidden_states_out, |
| 152 | + atol=atol, |
| 153 | + rtol=rtol), "Final hidden states do not match." |
0 commit comments