# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch

from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
    BatchedDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
    BatchedPrepareAndFinalize, BatchedTritonExperts)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
    FusedMoEModularKernel)
from vllm.utils.deep_gemm import calc_diff, is_deep_gemm_supported

from .test_deepgemm import make_block_quant_fp8_weights

BLOCK_SIZE = [128, 128]


@pytest.mark.skipif(not is_deep_gemm_supported(),
                    reason="Requires deep_gemm kernels")
@pytest.mark.parametrize("E", [16, 32])  # number of experts
@pytest.mark.parametrize("T", [256, 512])  # tokens per expert
@pytest.mark.parametrize("K", [128, 256])  # hidden dim
@pytest.mark.parametrize("N", [512, 1024])  # intermediate dim per expert
@pytest.mark.parametrize("topk", [2, 4])
def test_batched_deepgemm_vs_triton(E: int, T: int, K: int, N: int, topk: int,
                                    monkeypatch):
    """Compare BatchedDeepGemmExperts to BatchedTritonExperts."""

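    # Both implementations run through the same FusedMoEModularKernel wrapper
    # and share one BatchedPrepareAndFinalize instance, so any mismatch in
    # the outputs isolates to the expert compute path (DeepGEMM vs. Triton).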
    monkeypatch.setenv("VLLM_USE_DEEP_GEMM", "1")

    device = "cuda"
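    # w1/w2 are fp8 weights block-quantized with one scale per 128x128 block
    # (w1_s / w2_s), matching BLOCK_SIZE above.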
    w1, w2, w1_s, w2_s = make_block_quant_fp8_weights(E, N, K, BLOCK_SIZE)

    M = E * T  # total tokens
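    # Keep activations small and clamp to the fp8 e4m3 representable range
    # so the block quantization performed inside the kernels cannot overflow.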
    a = torch.randn(M, K, device=device, dtype=torch.bfloat16) / 10.0
    fp8_info = torch.finfo(torch.float8_e4m3fn)
    a.clamp_(fp8_info.min, fp8_info.max)

    # random router outputs → top-k indices / weights
    router_logits = torch.randn(M, E, device=device, dtype=torch.float32)
    topk_weights, topk_ids = torch.topk(router_logits, k=topk, dim=-1)
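    # normalize only the selected top-k logits so weights sum to 1 per token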
    topk_weights = torch.nn.functional.softmax(topk_weights, dim=-1)

    # number of tokens routed to each expert
    cnt = torch.bincount(topk_ids.flatten(), minlength=E)
    max_cnt = int(cnt.max().item())
    # round the per-expert token count up to the next power of two
    max_num_tokens = 1 << (max_cnt - 1).bit_length()
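    # e.g. max_cnt = 300 -> (299).bit_length() = 9 -> max_num_tokens = 512;
    # each expert's token batch is padded up to this fixed capacity by the
    # batched prepare/finalize step.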

    prep_finalize = BatchedPrepareAndFinalize(
        max_num_tokens=max_num_tokens,
        num_local_experts=E,
        num_dispatchers=1,
        rank=0,
    )
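    # num_dispatchers=1 / rank=0: single-rank setup, i.e. all experts are
    # local and there is no cross-rank dispatch (no expert parallelism).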

    # triton (reference)
    triton_experts = BatchedTritonExperts(
        max_num_tokens=max_num_tokens,
        num_dispatchers=1,
        use_fp8_w8a8=True,
        per_act_token_quant=False,
        block_shape=BLOCK_SIZE,
    )
    mk_triton = FusedMoEModularKernel(prep_finalize, triton_experts)

    out_triton = mk_triton(
        hidden_states=a,
        w1=w1,
        w2=w2,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        inplace=False,
        w1_scale=w1_s,
        w2_scale=w2_s,
        global_num_experts=E,
    )

    # deepgemm
    deepgemm_experts = BatchedDeepGemmExperts(
        max_num_tokens=max_num_tokens,
        num_dispatchers=1,
        block_shape=BLOCK_SIZE,
        per_act_token_quant=False,
    )
    mk_deepgemm = FusedMoEModularKernel(prep_finalize, deepgemm_experts)

    out_deepgemm = mk_deepgemm(
        hidden_states=a,
        w1=w1,
        w2=w2,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        inplace=False,
        w1_scale=w1_s,
        w2_scale=w2_s,
        global_num_experts=E,
    )

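    # calc_diff is ~0 for matching outputs; the 1e-3 bound leaves headroom
    # for the fp8 quantization noise accumulated along the two kernel paths.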
    diff = calc_diff(out_deepgemm, out_triton)
    assert diff < 1e-3, f"Output diff too large: {diff}"