Skip to content

Commit 2f00a93

Browse files
authored
test: added different dtype unittests for moe permute kernels (#431)
1 parent 3e09b2c commit 2f00a93

File tree

2 files changed

+32
-8
lines changed

2 files changed

+32
-8
lines changed

src/kernels/moe/permute_kernel.cu

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,25 @@
66

77
#include <cub/cub.cuh>
88

9+
// clang-format off
10+
// for exmple: n_tokens = 2, n_experts = 8, topk = 2
11+
// ____________________________________________________________________________________________________________________________
12+
// | | flatten indices | sort flatten indices | row_id_map |
13+
// | Steps | sort by (tokens, topk) | by (experts, tokens) | sort by (topk, tokens) |
14+
// |_________________|_____________________________|______________________________________|_____________________________________|
15+
// | | [n_tokens * topk] | [n_tokens * topk] => f_idx | [topk, n_tokens] => p_idx |
16+
// | Dim | | f_idx: idx in flatten indices | p_idx: idx in permuted tokens |
17+
// |_________________|_____________________________|______________________________________|_____________________________________|
18+
// | | | | |
19+
// | top0, top1 | f_idx: | 0 | 1 | 2 | 3 | | p_idx: | 0 | 1 | 2 | 3 | | idx: | 0 | 1 | 2 | 3 | |
20+
// | t0 -> [e2, e1] | experts: | 2 | 1 | 2 | 5 | | f_idx: | 1 | 0 | 2 | 3 | | p_idx: | 1 | 2 | 0 | 3 | |
21+
// | t1 -> [e2, e5] | tokens: | t0 | t1 | | tokens: | t0 | t0 | t1 | t1 | | f_idx: | 0 | 2 | 1 | 3 | |
22+
// | | | experts: | e1 | e2 | e5 | | experts: | e2 | e2 | e1 | e5 | |
23+
// | | | | tokens: | t0 | t1 | t0 | t1 | |
24+
// | | | | topk: | top0 | top1 | |
25+
// |_________________|_____________________________|______________________________________|_____________________________________|
26+
// clang-format on
27+
928
namespace llm::kernel::moe {
1029

1130
namespace {
@@ -14,7 +33,7 @@ inline T* get_ptr(torch::Tensor& t) {
1433
return reinterpret_cast<T*>(t.data_ptr());
1534
}
1635

17-
// one thread per permuted token
36+
// build a row_id_map that maps [topk, n_tokens] to the index in permuted tokens
1837
__global__ void permute_row_id_map(
1938
const int* sorted_row_id, // [n_permuted_tokens]
2039
int* row_id_map, // [topk, n_tokens]
@@ -115,6 +134,7 @@ __global__ void unpermute_kernel(
115134
}
116135
__syncthreads();
117136

137+
// TODO: use float for accumulator
118138
Fragment frag_sum;
119139
Fragment frag;
120140

src/kernels/moe/permute_kernel_test.cu

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -103,18 +103,22 @@ TEST_P(PermuteTest, Index) {
103103

104104
auto ref_unpermute_out = unpermute_ref(
105105
ref_permuted_tokens, ref_sorted_indices, probs, n_tokens, topk);
106-
EXPECT_TRUE(torch::allclose(unpermute_out, ref_unpermute_out));
107-
EXPECT_TRUE(torch::allclose(tokens, unpermute_out));
106+
EXPECT_TRUE(torch::allclose(
107+
unpermute_out, ref_unpermute_out, /*rtol=*/1e-2, /*atol=*/1e-2));
108+
EXPECT_TRUE(
109+
torch::allclose(tokens, unpermute_out, /*rtol=*/1e-2, /*atol=*/1e-2));
108110
}
109111

110112
INSTANTIATE_TEST_SUITE_P(
111113
Moe,
112114
PermuteTest,
113-
::testing::Combine(::testing::Values(torch::kFloat), // dtype
114-
::testing::Values(1, 2, 16), // n_tokens
115-
::testing::Values(16, 64), // dim
116-
::testing::Values(4, 8, 16), // n_experts
117-
::testing::Values(1, 2, 4) // topk
115+
::testing::Combine(::testing::Values(torch::kFloat,
116+
torch::kHalf,
117+
torch::kBFloat16), // dtype
118+
::testing::Values(1, 2, 16), // n_tokens
119+
::testing::Values(16, 64), // dim
120+
::testing::Values(4, 8, 16), // n_experts
121+
::testing::Values(1, 2, 4) // topk
118122
));
119123

120124
} // namespace llm

0 commit comments

Comments
 (0)