
Commit 7d94577

Add torch golden impl for moe_align_block_size kernel test (#20653)
Signed-off-by: Shixian Cui <[email protected]>
Co-authored-by: Shixian Cui <[email protected]>
1 parent 59f9353 commit 7d94577

1 file changed: 296 additions & 71 deletions
@@ -1,90 +1,315 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import itertools
+"""Tests for the MOE align block size function.
+
+Run `pytest tests/kernels/moe/test_moe_align_block_size.py`.
+"""
+
+from typing import Optional
 
 import pytest
 import torch
 
-from vllm import _custom_ops as ops
 from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
-    moe_align_block_size_triton)
-
-
-@pytest.mark.parametrize(
-    "block_size,num_tokens,topk,num_experts",
-    list(
-        itertools.product(
-            [32, 64, 128, 256],  # block_size
-            [
-                1,
-                3,
-                7,
-                16,
-                256,
-                2256,
-                4096,
-            ],  # num_tokens
-            [1, 4, 16, 64],  # topk
-            [64, 160, 256, 257, 260, 264],  # num_experts
-        )),
-)
-def test_moe_align_block_size_compare_implementations(block_size, num_tokens,
-                                                       topk, num_experts):
-    topk_ids = torch.stack([
-        torch.randperm(num_experts, dtype=torch.int32, device="cuda")[:topk]
-        for _ in range(num_tokens)
-    ])
+    moe_align_block_size)
+from vllm.platforms import current_platform
+from vllm.utils import round_up
+
+NUM_TOKENS = [1, 3, 7, 16, 256, 2256, 4096]
+NUM_EXPERTS = [32, 160, 256, 257, 512]
+TOP_KS = [1, 2, 16, 32]
+BLOCK_SIZES = [32, 64, 128, 256]
+current_platform.seed_everything(0)
+
+
+def _group_tokens_by_expert(
+    sorted_ids: torch.Tensor,
+    expert_ids: torch.Tensor,
+    block_size: int,
+    valid_length: int,
+    total_tokens: int,
+) -> dict:
+    num_blocks = valid_length // block_size
+    expert_tokens: dict[int, list[int]] = {}
+
+    for block_idx in range(num_blocks):
+        expert_id = expert_ids[block_idx].item()
+        block_start = block_idx * block_size
+        block_end = min(block_start + block_size, valid_length)
+
+        block_tokens = sorted_ids[block_start:block_end]
+        valid_tokens = block_tokens[block_tokens < total_tokens]
+
+        if expert_id not in expert_tokens:
+            expert_tokens[expert_id] = []
+        expert_tokens[expert_id].extend(valid_tokens.tolist())
+    return expert_tokens
+
 
+def _verify_expert_level_sorting(
+    actual_sorted_ids: torch.Tensor,
+    golden_sorted_ids: torch.Tensor,
+    expert_ids: torch.Tensor,
+    block_size: int,
+    valid_length: int,
+    total_tokens: int,
+):
+    """
+    Verify that actual_sorted_ids follows the correct expert-level sorting.
+    The kernel implementation may or may not preserve the original token order
+    in topk_ids in the final sorted_ids; however, this does not impact quality.
+    """
+    # Group tokens by expert from the golden implementation
+    golden_expert_tokens = _group_tokens_by_expert(golden_sorted_ids,
+                                                   expert_ids, block_size,
+                                                   valid_length, total_tokens)
+
+    actual_expert_tokens = _group_tokens_by_expert(actual_sorted_ids,
+                                                   expert_ids, block_size,
+                                                   valid_length, total_tokens)
+
+    assert set(golden_expert_tokens.keys()) == set(
+        actual_expert_tokens.keys()), (
+            f"Expert IDs mismatch: golden={set(golden_expert_tokens.keys())}, "
+            f"actual={set(actual_expert_tokens.keys())}")
+
+    for expert_id in golden_expert_tokens:
+        golden_tokens = torch.tensor(golden_expert_tokens[expert_id],
+                                     device=actual_sorted_ids.device)
+        actual_tokens = torch.tensor(actual_expert_tokens[expert_id],
+                                     device=actual_sorted_ids.device)
+        assert torch.equal(
+            torch.sort(golden_tokens)[0],
+            torch.sort(actual_tokens)[0]), (
+                f"Expert {expert_id} token mismatch: "
+                f"golden={golden_expert_tokens[expert_id]}, "
+                f"actual={actual_expert_tokens[expert_id]}")
+
+
+def torch_moe_align_block_size(
+    topk_ids: torch.Tensor,
+    block_size: int,
+    num_experts: int,
+    expert_map: Optional[torch.Tensor] = None,
+    pad_sorted_ids: bool = False,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """
+    Golden torch implementation of moe_align_block_size.
+
+    This function aligns the token distribution across experts to be compatible
+    with block size for matrix multiplication by sorting tokens by expert and
+    padding to block boundaries.
+    """
     max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
+    if pad_sorted_ids:
+        max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
+
+    flattened_token_indices = torch.arange(topk_ids.numel(),
+                                           device=topk_ids.device,
+                                           dtype=torch.int32)
+    flattened_expert_ids = topk_ids.flatten()
+    sorted_expert_ids, sort_indices = torch.sort(flattened_expert_ids,
+                                                 stable=True)
+    sorted_token_indices = flattened_token_indices[sort_indices]
+
+    expert_token_counts = torch.zeros(num_experts,
+                                      dtype=torch.int64,
+                                      device=topk_ids.device)
+    for expert_id in range(num_experts):
+        mask = sorted_expert_ids == expert_id
+        expert_token_counts[expert_id] = mask.sum()
+
+    expert_padded_counts = torch.zeros(num_experts,
+                                       dtype=torch.int64,
+                                       device=topk_ids.device)
+    for expert_id in range(num_experts):
+        original_count = expert_token_counts[expert_id]
+        if original_count > 0:
+            expert_padded_counts[expert_id] = (
+                (original_count + block_size - 1) // block_size) * block_size
 
-    sorted_ids_cuda = torch.empty((max_num_tokens_padded, ),
-                                  dtype=torch.int32,
-                                  device=topk_ids.device)
-    sorted_ids_cuda.fill_(topk_ids.numel())
-    max_num_m_blocks = max_num_tokens_padded // block_size
-    expert_ids_cuda = torch.zeros((max_num_m_blocks, ),
-                                  dtype=torch.int32,
-                                  device=topk_ids.device)
-    num_tokens_post_pad_cuda = torch.empty((1),
-                                           dtype=torch.int32,
-                                           device=topk_ids.device)
-
-    sorted_ids_triton = torch.empty_like(sorted_ids_cuda)
-    sorted_ids_triton.fill_(topk_ids.numel())
-    expert_ids_triton = torch.zeros_like(expert_ids_cuda)
-    num_tokens_post_pad_triton = torch.empty_like(num_tokens_post_pad_cuda)
-
-    ops.moe_align_block_size(
-        topk_ids,
-        num_experts,
+    sorted_token_ids = torch.full(
+        (max_num_tokens_padded, ),
+        topk_ids.numel(),
+        dtype=torch.int32,
+        device=topk_ids.device,
+    )
+    max_num_blocks = (max_num_tokens_padded + block_size - 1) // block_size
+    expert_ids = torch.zeros(max_num_blocks,
+                             dtype=torch.int32,
+                             device=topk_ids.device)
+
+    current_pos = 0
+    current_block = 0
+    for expert_id in range(num_experts):
+        expert_mask = sorted_expert_ids == expert_id
+        expert_tokens = sorted_token_indices[expert_mask]
+        num_expert_tokens = expert_tokens.shape[0]
+
+        if num_expert_tokens > 0:
+            sorted_token_ids[current_pos:current_pos +
+                             num_expert_tokens] = (expert_tokens)
+
+            expert_blocks_needed = expert_padded_counts[expert_id] // block_size
+            expert_ids[current_block:current_block +
+                       expert_blocks_needed] = (expert_id)
+
+            current_pos += expert_padded_counts[expert_id]
+            current_block += expert_blocks_needed
+
+    total_padded_tokens = expert_padded_counts.sum()
+    num_tokens_post_pad = torch.tensor([total_padded_tokens],
+                                       dtype=torch.int32,
+                                       device=topk_ids.device)
+
+    if expert_map is not None:
+        expert_ids = expert_map[expert_ids]
+    return sorted_token_ids, expert_ids, num_tokens_post_pad
+
+
+@pytest.mark.parametrize("m", NUM_TOKENS)
+@pytest.mark.parametrize("topk", TOP_KS)
+@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
+@pytest.mark.parametrize("block_size", BLOCK_SIZES)
+@pytest.mark.parametrize("pad_sorted_ids", [False, True])
+@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
+def test_moe_align_block_size(m: int, topk: int, num_experts: int,
+                              block_size: int, pad_sorted_ids: bool):
+    """Test moe_align_block_size without expert mapping"""
+    topk_ids = torch.zeros((m, topk), device="cuda", dtype=torch.int32)
+    for i in range(m):
+        experts = torch.randperm(num_experts, device="cuda")[:topk]
+        topk_ids[i] = experts
+
+    actual_sorted_ids, actual_expert_ids, actual_num_tokens = (
+        moe_align_block_size(
+            topk_ids=topk_ids,
+            block_size=block_size,
+            num_experts=num_experts,
+            pad_sorted_ids=pad_sorted_ids,
+        ))
+    golden_sorted_ids, golden_expert_ids, golden_num_tokens = (
+        torch_moe_align_block_size(
+            topk_ids=topk_ids,
+            block_size=block_size,
+            num_experts=num_experts,
+            pad_sorted_ids=pad_sorted_ids,
+        ))
+
+    torch.testing.assert_close(actual_num_tokens,
+                               golden_num_tokens,
+                               atol=0,
+                               rtol=0)
+    torch.testing.assert_close(actual_expert_ids,
+                               golden_expert_ids,
+                               atol=0,
+                               rtol=0)
+
+    # For sorted_token_ids, verify block-level correctness rather than exact
+    # order. Tokens within each expert's blocks can be in any order, but
+    # expert regions must be correct.
+    _verify_expert_level_sorting(
+        actual_sorted_ids,
+        golden_sorted_ids,
+        actual_expert_ids,
         block_size,
-        sorted_ids_cuda,
-        expert_ids_cuda,
-        num_tokens_post_pad_cuda,
+        actual_num_tokens.item(),
+        m * topk,
     )
 
-    moe_align_block_size_triton(
-        topk_ids,
-        num_experts,
+    total_tokens = m * topk
+    assert actual_num_tokens.item() % block_size == 0, (
+        "num_tokens_post_pad should be divisible by block_size")
+    assert actual_num_tokens.item() >= total_tokens, (
+        "num_tokens_post_pad should be at least total_tokens")
+    valid_tokens = actual_sorted_ids[actual_sorted_ids < total_tokens]
+    assert len(valid_tokens) == total_tokens, (
+        f"Should have exactly {total_tokens} valid tokens, "
+        f"got {len(valid_tokens)}")
+    assert (actual_expert_ids >= 0).all() and (
+        actual_expert_ids
+        < num_experts).all(), "expert_ids should contain valid expert indices"
+
+
+@pytest.mark.parametrize("m", [16, 32])
+@pytest.mark.parametrize("topk", [2, 4])
+@pytest.mark.parametrize("num_experts", [8])
+@pytest.mark.parametrize("block_size", [64])
+@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
+def test_moe_align_block_size_with_expert_map(m: int, topk: int,
+                                              num_experts: int,
+                                              block_size: int):
+    """Test moe_align_block_size with expert mapping (EP scenario)"""
+    topk_ids = torch.zeros((m, topk), device="cuda", dtype=torch.int32)
+    for i in range(m):
+        experts = torch.randperm(num_experts, device="cuda")[:topk]
+        topk_ids[i] = experts
+
+    expert_map = torch.full((num_experts, ),
+                            -1,
+                            device="cuda",
+                            dtype=torch.int32)
+    local_experts = list(range(0, num_experts, 2))
+    for i, expert_id in enumerate(local_experts):
+        expert_map[expert_id] = i
+
+    actual_sorted_ids, actual_expert_ids, actual_num_tokens = (
+        moe_align_block_size(
+            topk_ids=topk_ids,
+            block_size=block_size,
+            num_experts=num_experts,
+            expert_map=expert_map,
+        ))
+    golden_sorted_ids, golden_expert_ids, golden_num_tokens = (
+        torch_moe_align_block_size(
+            topk_ids=topk_ids,
+            block_size=block_size,
+            num_experts=num_experts,
+            expert_map=expert_map,
+        ))
+
+    torch.testing.assert_close(actual_num_tokens,
+                               golden_num_tokens,
+                               atol=0,
+                               rtol=0)
+    torch.testing.assert_close(actual_expert_ids,
+                               golden_expert_ids,
+                               atol=0,
+                               rtol=0)
+    _verify_expert_level_sorting(
+        actual_sorted_ids,
+        golden_sorted_ids,
+        actual_expert_ids,
         block_size,
-        sorted_ids_triton,
-        expert_ids_triton,
-        num_tokens_post_pad_triton,
+        actual_num_tokens.item(),
+        m * topk,
     )
 
-    assert torch.allclose(expert_ids_cuda, expert_ids_triton), (
-        f"Expert IDs mismatch for block_size={block_size}, "
-        f"num_tokens={num_tokens}, topk={topk}\n"
-        f"CUDA expert_ids: {expert_ids_cuda}\n"
-        f"Triton expert_ids: {expert_ids_triton}")
 
-    assert torch.allclose(
-        num_tokens_post_pad_cuda, num_tokens_post_pad_triton), (
-        f"Num tokens post pad mismatch for block_size={block_size}, "
-        f"num_tokens={num_tokens}, topk={topk}\n"
-        f"CUDA num_tokens_post_pad: {num_tokens_post_pad_cuda}\n"
-        f"Triton num_tokens_post_pad: {num_tokens_post_pad_triton}")
+def test_moe_align_block_size_deterministic():
+    m, topk, num_experts, block_size = 128, 2, 32, 64
+
+    torch.manual_seed(42)
+    topk_ids = torch.randint(0,
+                             num_experts, (m, topk),
+                             device="cuda",
+                             dtype=torch.int32)
 
+    # expect the results to be reproducible
+    results = []
+    for _ in range(5):
+        sorted_ids, expert_ids, num_tokens = moe_align_block_size(
+            topk_ids=topk_ids, block_size=block_size, num_experts=num_experts)
+        results.append(
+            (sorted_ids.clone(), expert_ids.clone(), num_tokens.clone()))
 
-if __name__ == "__main__":
-    pytest.main([__file__])
+    for i in range(1, len(results)):
+        assert torch.equal(
+            results[0][0],
+            results[i][0]), ("sorted_ids should be deterministic")
+        assert torch.equal(
+            results[0][1],
+            results[i][1]), ("expert_ids should be deterministic")
+        assert torch.equal(
+            results[0][2],
+            results[i][2]), ("num_tokens should be deterministic")
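
For a concrete feel of the alignment being tested, here is a tiny worked example. This is a minimal sketch for illustration only: it assumes torch_moe_align_block_size from the diff above is in scope (for example, in a session that has executed the test module), the input sizes are hypothetical, and the expected values are worked out by hand from that golden logic (which runs fine on CPU since it is pure torch):

import torch

# Hypothetical tiny case: 2 tokens, topk=2, 4 experts, block_size=4.
# Row i of topk_ids holds the experts selected for token i.
topk_ids = torch.tensor([[0, 2], [1, 2]], dtype=torch.int32)

sorted_ids, expert_ids, num_post_pad = torch_moe_align_block_size(
    topk_ids=topk_ids, block_size=4, num_experts=4)

# Flattened (token, expert) slots 0..3 are grouped by expert, and each group is
# padded up to a multiple of block_size with the sentinel topk_ids.numel() == 4:
#   sorted_ids[:12] -> [0, 4, 4, 4,  2, 4, 4, 4,  1, 3, 4, 4]
#   expert_ids      -> [0, 1, 2, 0]  (one expert per block; trailing block unused)
#   num_post_pad    -> [12]          (3 occupied blocks * block_size)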
