Commit 6e8d8c4

[Test] Add Unit Test for Batched DeepGEMM (#21559)
Signed-off-by: yewentao256 <[email protected]>
1 parent 8d524ce commit 6e8d8c4

File tree

3 files changed: +107 -8 lines changed
Lines changed: 103 additions & 0 deletions
@@ -0,0 +1,103 @@

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch

from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
    BatchedDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
    BatchedPrepareAndFinalize, BatchedTritonExperts)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
    FusedMoEModularKernel)
from vllm.utils.deep_gemm import calc_diff, is_deep_gemm_supported

from .test_deepgemm import make_block_quant_fp8_weights

BLOCK_SIZE = [128, 128]


@pytest.mark.skipif(not is_deep_gemm_supported(),
                    reason="Requires deep_gemm kernels")
@pytest.mark.parametrize("E", [16, 32])  # number of experts
@pytest.mark.parametrize("T", [256, 512])  # tokens per expert
@pytest.mark.parametrize("K", [128, 256])  # hidden dim
@pytest.mark.parametrize("N", [512, 1024])  # intermediate dim per expert
@pytest.mark.parametrize("topk", [2, 4])
def test_batched_deepgemm_vs_triton(E: int, T: int, K: int, N: int, topk: int,
                                    monkeypatch):
    """Compare BatchedDeepGemmExperts to BatchedTritonExperts."""

    monkeypatch.setenv("VLLM_USE_DEEP_GEMM", "1")

    device = "cuda"
    w1, w2, w1_s, w2_s = make_block_quant_fp8_weights(E, N, K, BLOCK_SIZE)

    M = E * T  # total tokens
    a = torch.randn(M, K, device=device, dtype=torch.bfloat16) / 10.0
    fp8_info = torch.finfo(torch.float8_e4m3fn)
    a.clamp_(fp8_info.min, fp8_info.max)

    # random router outputs → top-k indices / weights
    router_logits = torch.randn(M, E, device=device, dtype=torch.float32)
    topk_weights, topk_ids = torch.topk(router_logits, k=topk, dim=-1)
    topk_weights = torch.nn.functional.softmax(topk_weights, dim=-1)

    # token count for each expert
    cnt = torch.bincount(topk_ids.flatten(), minlength=E)
    max_cnt = int(cnt.max().item())
    # next power of 2 of the max token count
    max_num_tokens = 1 << (max_cnt - 1).bit_length()

    prep_finalize = BatchedPrepareAndFinalize(
        max_num_tokens=max_num_tokens,
        num_local_experts=E,
        num_dispatchers=1,
        rank=0,
    )

    # triton (reference)
    triton_experts = BatchedTritonExperts(
        max_num_tokens=max_num_tokens,
        num_dispatchers=1,
        use_fp8_w8a8=True,
        per_act_token_quant=False,
        block_shape=BLOCK_SIZE,
    )
    mk_triton = FusedMoEModularKernel(prep_finalize, triton_experts)

    out_triton = mk_triton(
        hidden_states=a,
        w1=w1,
        w2=w2,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        inplace=False,
        w1_scale=w1_s,
        w2_scale=w2_s,
        global_num_experts=E,
    )

    # deepgemm
    deepgemm_experts = BatchedDeepGemmExperts(
        max_num_tokens=max_num_tokens,
        num_dispatchers=1,
        block_shape=BLOCK_SIZE,
        per_act_token_quant=False,
    )
    mk_deepgemm = FusedMoEModularKernel(prep_finalize, deepgemm_experts)

    out_deepgemm = mk_deepgemm(
        hidden_states=a,
        w1=w1,
        w2=w2,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        inplace=False,
        w1_scale=w1_s,
        w2_scale=w2_s,
        global_num_experts=E,
    )

    diff = calc_diff(out_deepgemm, out_triton)
    assert diff < 1e-3, f"Output diff too large: {diff}"
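
The max_num_tokens computation above sizes each expert's batch buffer to the next power of two of the busiest expert's token count. A minimal standalone sketch of that step, with made-up routing purely for illustration:

import torch

def next_pow2(n: int) -> int:
    # 1 << (n - 1).bit_length() rounds n up to the nearest power of two,
    # e.g. 3 -> 4 and 8 -> 8; it assumes n >= 1.
    return 1 << (n - 1).bit_length()

# three tokens, top-2 routing over four experts (values invented for the demo)
topk_ids = torch.tensor([[0, 2], [1, 2], [2, 3]])
cnt = torch.bincount(topk_ids.flatten(), minlength=4)  # tensor([1, 1, 3, 1])
print(next_pow2(int(cnt.max().item())))  # busiest expert holds 3 tokens -> bucket of 4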

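The test passes when calc_diff stays below 1e-3. For orientation only, a cosine-style error metric of the kind DeepGEMM-flavored tests use looks like the sketch below; vLLM's actual calc_diff lives in vllm.utils.deep_gemm and may differ in detail:

import torch

def cosine_style_diff(x: torch.Tensor, y: torch.Tensor) -> float:
    # 0.0 when x == y; approaches 1.0 as the tensors decorrelate.
    x, y = x.double(), y.double()
    sim = 2 * (x * y).sum() / (x.square() + y.square()).sum()
    return float(1 - sim)

a = torch.randn(4, 8)
print(cosine_style_diff(a, a))         # 0.0
print(cosine_style_diff(a, 1.01 * a))  # ~5e-5, small but nonzero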
tests/kernels/moe/test_deepgemm.py

Lines changed: 2 additions & 6 deletions
@@ -20,11 +20,6 @@
 
 BLOCK_SIZE = [128, 128]
 
-requires_deep_gemm = pytest.mark.skipif(
-    not is_deep_gemm_supported(),
-    reason="Requires deep_gemm kernels",
-)
-
 
 def make_block_quant_fp8_weights(
     e: int,
@@ -152,7 +147,8 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
 @pytest.mark.parametrize("mnk", MNKs)
 @pytest.mark.parametrize("topk", TOPKS)
 @pytest.mark.parametrize("num_experts", NUM_EXPERTS)
-@requires_deep_gemm
+@pytest.mark.skipif(not is_deep_gemm_supported(),
+                    reason="Requires deep_gemm kernels")
 def test_deepgemm_vs_triton(mnk, topk, num_experts, monkeypatch):
 
     with monkeypatch.context() as m:
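
This hunk inlines the skip condition rather than keeping a module-level marker that only one test used. The two spellings are interchangeable in pytest; a self-contained sketch with a made-up feature check:

import pytest

def _feature_available() -> bool:
    # stand-in for is_deep_gemm_supported(); hardcoded for the demo
    return False

# reusable module-level marker (the form this commit removes)
requires_feature = pytest.mark.skipif(not _feature_available(),
                                      reason="feature missing")

@requires_feature
def test_old_style():
    assert True

# inline marker (the form this commit switches to)
@pytest.mark.skipif(not _feature_available(),
                    reason="feature missing")
def test_new_style():
    assert True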

vllm/utils/deep_gemm.py

Lines changed: 2 additions & 2 deletions
@@ -23,10 +23,10 @@ def is_deep_gemm_supported() -> bool:
     """Return ``True`` if DeepGEMM is supported on the current platform.
     Currently, only Hopper and Blackwell GPUs are supported.
     """
-    supported_arch = current_platform.is_cuda() and (
+    is_supported_arch = current_platform.is_cuda() and (
         current_platform.is_device_capability(90)
         or current_platform.is_device_capability(100))
-    return has_deep_gemm() and supported_arch
+    return has_deep_gemm() and is_supported_arch
 
 
 @functools.cache
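
is_device_capability(90) matches Hopper (SM 9.0) and is_device_capability(100) matches Blackwell (SM 10.0). For orientation, a rough plain-PyTorch equivalent of the renamed check; the helper below is illustrative, not vLLM's actual platform layer:

import torch

def is_deep_gemm_arch() -> bool:
    # Hopper reports compute capability (9, 0); Blackwell data-center
    # parts report (10, 0).
    if not torch.cuda.is_available():
        return False
    return torch.cuda.get_device_capability() in ((9, 0), (10, 0))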
