 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import itertools
+"""Tests for the MOE align block size function.
+
+Run `pytest tests/kernels/moe/test_moe_align_block_size.py`.
+"""
+
+from typing import Optional
 
 import pytest
 import torch
 
-from vllm import _custom_ops as ops
 from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
-    moe_align_block_size_triton)
-
-
-@pytest.mark.parametrize(
-    "block_size,num_tokens,topk,num_experts",
-    list(
-        itertools.product(
-            [32, 64, 128, 256],  # block_size
-            [
-                1,
-                3,
-                7,
-                16,
-                256,
-                2256,
-                4096,
-            ],  # num_tokens
-            [1, 4, 16, 64],  # topk
-            [64, 160, 256, 257, 260, 264],  # num_experts
-        )),
-)
-def test_moe_align_block_size_compare_implementations(block_size, num_tokens,
-                                                      topk, num_experts):
-    topk_ids = torch.stack([
-        torch.randperm(num_experts, dtype=torch.int32, device="cuda")[:topk]
-        for _ in range(num_tokens)
-    ])
+    moe_align_block_size)
+from vllm.platforms import current_platform
+from vllm.utils import round_up
+
+NUM_TOKENS = [1, 3, 7, 16, 256, 2256, 4096]
+NUM_EXPERTS = [32, 160, 256, 257, 512]
+TOP_KS = [1, 2, 16, 32]
+BLOCK_SIZES = [32, 64, 128, 256]
+current_platform.seed_everything(0)
+
+
+def _group_tokens_by_expert(
+    sorted_ids: torch.Tensor,
+    expert_ids: torch.Tensor,
+    block_size: int,
+    valid_length: int,
+    total_tokens: int,
+) -> dict:
+    """Group the valid token indices in sorted_ids by the expert that owns
+    each block, as given by expert_ids."""
+    num_blocks = valid_length // block_size
+    expert_tokens: dict[int, list[int]] = {}
+
+    for block_idx in range(num_blocks):
+        expert_id = expert_ids[block_idx].item()
+        block_start = block_idx * block_size
+        block_end = min(block_start + block_size, valid_length)
+
+        block_tokens = sorted_ids[block_start:block_end]
+        valid_tokens = block_tokens[block_tokens < total_tokens]
+
+        if expert_id not in expert_tokens:
+            expert_tokens[expert_id] = []
+        expert_tokens[expert_id].extend(valid_tokens.tolist())
+    return expert_tokens
+
 
+def _verify_expert_level_sorting(
+    actual_sorted_ids: torch.Tensor,
+    golden_sorted_ids: torch.Tensor,
+    expert_ids: torch.Tensor,
+    block_size: int,
+    valid_length: int,
+    total_tokens: int,
+):
+    """
+    Verify that actual_sorted_ids follows the correct expert-level sorting.
+    The kernel implementation may or may not preserve the original token order
+    from topk_ids in the final sorted_ids; this does not affect correctness.
+    """
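+    # Illustrative note (not actual test data): with block_size=2, the
+    # orderings [0, 2 | 1, ...] and [2, 0 | 1, ...] are both accepted here,
+    # as long as the same set of tokens lands in each expert's blocks.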
+    # Group tokens by expert from the golden implementation
+    golden_expert_tokens = _group_tokens_by_expert(golden_sorted_ids,
+                                                   expert_ids, block_size,
+                                                   valid_length, total_tokens)
+
+    actual_expert_tokens = _group_tokens_by_expert(actual_sorted_ids,
+                                                   expert_ids, block_size,
+                                                   valid_length, total_tokens)
+
+    assert set(golden_expert_tokens.keys()) == set(
+        actual_expert_tokens.keys()), (
+            f"Expert IDs mismatch: golden={set(golden_expert_tokens.keys())}, "
+            f"actual={set(actual_expert_tokens.keys())}")
+
+    for expert_id in golden_expert_tokens:
+        golden_tokens = torch.tensor(golden_expert_tokens[expert_id],
+                                     device=actual_sorted_ids.device)
+        actual_tokens = torch.tensor(actual_expert_tokens[expert_id],
+                                     device=actual_sorted_ids.device)
+        assert torch.equal(
+            torch.sort(golden_tokens)[0],
+            torch.sort(actual_tokens)[0]), (
+                f"Expert {expert_id} token mismatch: "
+                f"golden={golden_expert_tokens[expert_id]}, "
+                f"actual={actual_expert_tokens[expert_id]}")
+
+
+def torch_moe_align_block_size(
+    topk_ids: torch.Tensor,
+    block_size: int,
+    num_experts: int,
+    expert_map: Optional[torch.Tensor] = None,
+    pad_sorted_ids: bool = False,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """
+    Golden torch implementation of moe_align_block_size.
+
+    This function aligns the token distribution across experts to be compatible
+    with block size for matrix multiplication by sorting tokens by expert and
+    padding to block boundaries.
+    """
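+    # Illustrative walk-through (hypothetical inputs, not taken from the
+    # tests): with block_size=4 and topk_ids=[[0, 1], [0, 2]], expert 0 owns
+    # flattened tokens {0, 2} and experts 1 and 2 own one token each. Every
+    # non-empty expert is padded to a whole block, so num_tokens_post_pad is
+    # 12 and sorted_token_ids begins [0, 2, 4, 4 | 1, 4, 4, 4 | 3, 4, 4, 4],
+    # where 4 (= topk_ids.numel()) is the padding value and expert_ids begins
+    # [0, 1, 2, ...].
+
+    # Worst case for the buffer sized below: every expert's last block may
+    # carry up to (block_size - 1) padding slots on top of the
+    # topk_ids.numel() real entries.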
     max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
+    if pad_sorted_ids:
+        max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
+
+    flattened_token_indices = torch.arange(topk_ids.numel(),
+                                           device=topk_ids.device,
+                                           dtype=torch.int32)
+    flattened_expert_ids = topk_ids.flatten()
+    sorted_expert_ids, sort_indices = torch.sort(flattened_expert_ids,
+                                                 stable=True)
+    sorted_token_indices = flattened_token_indices[sort_indices]
+
+    expert_token_counts = torch.zeros(num_experts,
+                                      dtype=torch.int64,
+                                      device=topk_ids.device)
+    for expert_id in range(num_experts):
+        mask = sorted_expert_ids == expert_id
+        expert_token_counts[expert_id] = mask.sum()
+
+    expert_padded_counts = torch.zeros(num_experts,
+                                       dtype=torch.int64,
+                                       device=topk_ids.device)
+    for expert_id in range(num_experts):
+        original_count = expert_token_counts[expert_id]
+        if original_count > 0:
+            expert_padded_counts[expert_id] = (
+                (original_count + block_size - 1) // block_size) * block_size
 
-    sorted_ids_cuda = torch.empty((max_num_tokens_padded, ),
-                                  dtype=torch.int32,
-                                  device=topk_ids.device)
-    sorted_ids_cuda.fill_(topk_ids.numel())
-    max_num_m_blocks = max_num_tokens_padded // block_size
-    expert_ids_cuda = torch.zeros((max_num_m_blocks, ),
-                                  dtype=torch.int32,
-                                  device=topk_ids.device)
-    num_tokens_post_pad_cuda = torch.empty((1),
-                                           dtype=torch.int32,
-                                           device=topk_ids.device)
-
-    sorted_ids_triton = torch.empty_like(sorted_ids_cuda)
-    sorted_ids_triton.fill_(topk_ids.numel())
-    expert_ids_triton = torch.zeros_like(expert_ids_cuda)
-    num_tokens_post_pad_triton = torch.empty_like(num_tokens_post_pad_cuda)
-
-    ops.moe_align_block_size(
-        topk_ids,
-        num_experts,
+    sorted_token_ids = torch.full(
+        (max_num_tokens_padded, ),
+        topk_ids.numel(),
+        dtype=torch.int32,
+        device=topk_ids.device,
+    )
+    max_num_blocks = (max_num_tokens_padded + block_size - 1) // block_size
+    expert_ids = torch.zeros(max_num_blocks,
+                             dtype=torch.int32,
+                             device=topk_ids.device)
+
+    current_pos = 0
+    current_block = 0
+    for expert_id in range(num_experts):
+        expert_mask = sorted_expert_ids == expert_id
+        expert_tokens = sorted_token_indices[expert_mask]
+        num_expert_tokens = expert_tokens.shape[0]
+
+        if num_expert_tokens > 0:
+            sorted_token_ids[current_pos:current_pos +
+                             num_expert_tokens] = (expert_tokens)
+
+        expert_blocks_needed = expert_padded_counts[expert_id] // block_size
+        expert_ids[current_block:current_block +
+                   expert_blocks_needed] = (expert_id)
+
+        current_pos += expert_padded_counts[expert_id]
+        current_block += expert_blocks_needed
+
+    total_padded_tokens = expert_padded_counts.sum()
+    num_tokens_post_pad = torch.tensor([total_padded_tokens],
+                                       dtype=torch.int32,
+                                       device=topk_ids.device)
+
+    if expert_map is not None:
+        expert_ids = expert_map[expert_ids]
+    return sorted_token_ids, expert_ids, num_tokens_post_pad
+
+
+@pytest.mark.parametrize("m", NUM_TOKENS)
+@pytest.mark.parametrize("topk", TOP_KS)
+@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
+@pytest.mark.parametrize("block_size", BLOCK_SIZES)
+@pytest.mark.parametrize("pad_sorted_ids", [False, True])
+@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
+def test_moe_align_block_size(m: int, topk: int, num_experts: int,
+                              block_size: int, pad_sorted_ids: bool):
+    """Test moe_align_block_size without expert mapping"""
+    topk_ids = torch.zeros((m, topk), device="cuda", dtype=torch.int32)
+    for i in range(m):
+        experts = torch.randperm(num_experts, device="cuda")[:topk]
+        topk_ids[i] = experts
+
+    actual_sorted_ids, actual_expert_ids, actual_num_tokens = (
+        moe_align_block_size(
+            topk_ids=topk_ids,
+            block_size=block_size,
+            num_experts=num_experts,
+            pad_sorted_ids=pad_sorted_ids,
+        ))
+    golden_sorted_ids, golden_expert_ids, golden_num_tokens = (
+        torch_moe_align_block_size(
+            topk_ids=topk_ids,
+            block_size=block_size,
+            num_experts=num_experts,
+            pad_sorted_ids=pad_sorted_ids,
+        ))
+
+    torch.testing.assert_close(actual_num_tokens,
+                               golden_num_tokens,
+                               atol=0,
+                               rtol=0)
+    torch.testing.assert_close(actual_expert_ids,
+                               golden_expert_ids,
+                               atol=0,
+                               rtol=0)
+
+    # For sorted_token_ids, verify block-level correctness rather than exact
+    # order. Tokens within each expert's blocks can appear in any order, but
+    # the expert regions must be correct.
+    _verify_expert_level_sorting(
+        actual_sorted_ids,
+        golden_sorted_ids,
+        actual_expert_ids,
         block_size,
-        sorted_ids_cuda,
-        expert_ids_cuda,
-        num_tokens_post_pad_cuda,
+        actual_num_tokens.item(),
+        m * topk,
     )
 
-    moe_align_block_size_triton(
-        topk_ids,
-        num_experts,
+    total_tokens = m * topk
+    assert actual_num_tokens.item() % block_size == 0, (
+        "num_tokens_post_pad should be divisible by block_size")
+    assert actual_num_tokens.item() >= total_tokens, (
+        "num_tokens_post_pad should be at least total_tokens")
+    valid_tokens = actual_sorted_ids[actual_sorted_ids < total_tokens]
+    assert len(valid_tokens) == total_tokens, (
+        f"Should have exactly {total_tokens} valid tokens, "
+        f"got {len(valid_tokens)}")
+    assert (actual_expert_ids >= 0).all() and (
+        actual_expert_ids
+        < num_experts).all(), "expert_ids should contain valid expert indices"
+
+
+@pytest.mark.parametrize("m", [16, 32])
+@pytest.mark.parametrize("topk", [2, 4])
+@pytest.mark.parametrize("num_experts", [8])
+@pytest.mark.parametrize("block_size", [64])
+@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
+def test_moe_align_block_size_with_expert_map(m: int, topk: int,
+                                              num_experts: int,
+                                              block_size: int):
+    """Test moe_align_block_size with expert mapping (EP scenario)"""
+    topk_ids = torch.zeros((m, topk), device="cuda", dtype=torch.int32)
+    for i in range(m):
+        experts = torch.randperm(num_experts, device="cuda")[:topk]
+        topk_ids[i] = experts
+
+    expert_map = torch.full((num_experts, ),
+                            -1,
+                            device="cuda",
+                            dtype=torch.int32)
+    local_experts = list(range(0, num_experts, 2))
+    for i, expert_id in enumerate(local_experts):
+        expert_map[expert_id] = i
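+    # With these parameters (num_experts=8), local_experts is [0, 2, 4, 6]
+    # and expert_map becomes [0, -1, 1, -1, 2, -1, 3, -1]: even (local)
+    # experts map to contiguous local indices, odd experts map to -1.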
+
+    actual_sorted_ids, actual_expert_ids, actual_num_tokens = (
+        moe_align_block_size(
+            topk_ids=topk_ids,
+            block_size=block_size,
+            num_experts=num_experts,
+            expert_map=expert_map,
+        ))
+    golden_sorted_ids, golden_expert_ids, golden_num_tokens = (
+        torch_moe_align_block_size(
+            topk_ids=topk_ids,
+            block_size=block_size,
+            num_experts=num_experts,
+            expert_map=expert_map,
+        ))
+
+    torch.testing.assert_close(actual_num_tokens,
+                               golden_num_tokens,
+                               atol=0,
+                               rtol=0)
+    torch.testing.assert_close(actual_expert_ids,
+                               golden_expert_ids,
+                               atol=0,
+                               rtol=0)
+    _verify_expert_level_sorting(
+        actual_sorted_ids,
+        golden_sorted_ids,
+        actual_expert_ids,
         block_size,
-        sorted_ids_triton,
-        expert_ids_triton,
-        num_tokens_post_pad_triton,
+        actual_num_tokens.item(),
+        m * topk,
     )
 
-    assert torch.allclose(expert_ids_cuda, expert_ids_triton), (
-        f"Expert IDs mismatch for block_size={block_size}, "
-        f"num_tokens={num_tokens}, topk={topk}\n"
-        f"CUDA expert_ids: {expert_ids_cuda}\n"
-        f"Triton expert_ids: {expert_ids_triton}")
 
-    assert torch.allclose(
-        num_tokens_post_pad_cuda, num_tokens_post_pad_triton), (
-            f"Num tokens post pad mismatch for block_size={block_size}, "
-            f"num_tokens={num_tokens}, topk={topk}\n"
-            f"CUDA num_tokens_post_pad: {num_tokens_post_pad_cuda}\n"
-            f"Triton num_tokens_post_pad: {num_tokens_post_pad_triton}")
+def test_moe_align_block_size_deterministic():
+    m, topk, num_experts, block_size = 128, 2, 32, 64
+
+    torch.manual_seed(42)
+    topk_ids = torch.randint(0,
+                             num_experts, (m, topk),
+                             device="cuda",
+                             dtype=torch.int32)
 
+    # expect the results to be reproducible
+    results = []
+    for _ in range(5):
+        sorted_ids, expert_ids, num_tokens = moe_align_block_size(
+            topk_ids=topk_ids, block_size=block_size, num_experts=num_experts)
+        results.append(
+            (sorted_ids.clone(), expert_ids.clone(), num_tokens.clone()))
 
-if __name__ == "__main__":
-    pytest.main([__file__])
+    for i in range(1, len(results)):
+        assert torch.equal(
+            results[0][0],
+            results[i][0]), ("sorted_ids should be deterministic")
+        assert torch.equal(
+            results[0][1],
+            results[i][1]), ("expert_ids should be deterministic")
+        assert torch.equal(
+            results[0][2],
+            results[i][2]), ("num_tokens should be deterministic")