diff --git a/csrc/moe/moe_align_sum_kernels.cpp b/csrc/moe/moe_align_sum_kernels.cpp index ad5b2d6..bfc64a8 100644 --- a/csrc/moe/moe_align_sum_kernels.cpp +++ b/csrc/moe/moe_align_sum_kernels.cpp @@ -6,9 +6,291 @@ #include "dispatch_utils.h" #include "utils.h" +#define CEILDIV(x, y) (((x) + (y) - 1) / (y)) + +// Round a up to the next multiple of b. The caller is responsible for making +// sure that b is non-zero +template +inline constexpr T round_to_next_multiple_of(T a, T b) { + return a % b == 0 ? a : ((a / b) + 1) * b; +} + namespace vllm { namespace moe { +constexpr int32_t WARP_SIZE = 32; + +namespace batched_moe_align_block_size { + +// Note num_threads needs to be 1024 for BlockScan Reduction in the kernel. +static constexpr int32_t num_threads = 1024; +static constexpr int32_t num_blocks = 1; + +class batched_moe_align_block_size_kernel { + private: + sycl::local_accessor slm; + int32_t const num_batches; + int32_t const max_tokens_per_batch; + int32_t const block_size; + int32_t const* __restrict__ batch_num_tokens; + int32_t* __restrict__ sorted_ids; + int32_t* __restrict__ block_ids; + int32_t* __restrict__ num_tokens_post_pad; + + public: + batched_moe_align_block_size_kernel( + sycl::local_accessor& slm, int32_t const num_batches, + int32_t const max_tokens_per_batch, int32_t const block_size, + int32_t const* __restrict__ batch_num_tokens, + int32_t* __restrict__ sorted_ids, int32_t* __restrict__ block_ids, + int32_t* __restrict__ num_tokens_post_pad) + : slm(slm), + num_batches(num_batches), + max_tokens_per_batch(max_tokens_per_batch), + block_size(block_size), + batch_num_tokens(batch_num_tokens), + sorted_ids(sorted_ids), + block_ids(block_ids), + num_tokens_post_pad(num_tokens_post_pad) {} + + void operator()(sycl::nd_item<1> item) const { + // TODO: This is a naive implementation. Could be optimized. + auto group = item.get_group(); + auto local_id_x = item.get_local_id(0); + auto local_range = item.get_local_range(0); + auto group_range = item.get_group_range(0); + + int32_t* temp_storage = static_cast( + slm.template get_multi_ptr().get()); + + size_t const batch_id = local_id_x; + size_t const stride = local_range * group_range; + int32_t const num_blocks_per_batch = + CEILDIV(max_tokens_per_batch, block_size); + int32_t const sorted_ids_size = + num_blocks_per_batch * num_batches * block_size; + int32_t const block_ids_size = sorted_ids_size / block_size; + int32_t const SENTINEL = + num_batches * max_tokens_per_batch; // To denote invalid entries. 
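+    // Output layout: sorted_ids has num_blocks_per_batch * num_batches *
+    // block_size slots and block_ids has one entry per block_size-sized
+    // block of sorted_ids. Slots that never receive a real token keep the
+    // SENTINEL value and blocks that never receive a batch keep -1, so
+    // downstream kernels can recognize and skip the padding.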
+ // Initialize sorted_ids + for (size_t i = local_id_x; i < sorted_ids_size; i += stride) { + sorted_ids[i] = SENTINEL; + } + // Initialize expert_ids with -1 + for (size_t i = local_id_x; i < block_ids_size; i += stride) { + block_ids[i] = -1; + } + + int32_t b_num_tokens = 0; + if (batch_id < num_batches) { + b_num_tokens = batch_num_tokens[batch_id]; + } + int32_t const ceil_b_num_tokens = + CEILDIV(b_num_tokens, block_size) * block_size; + + // Compute prefix sum over token counts per expert + temp_storage[local_id_x] = ceil_b_num_tokens; + item.barrier(sycl::access::fence_space::local_space); + + int cumsum_val; + sycl::joint_exclusive_scan(item.get_group(), temp_storage, + temp_storage + 1024, temp_storage, 0, + sycl::plus{}); + cumsum_val = temp_storage[local_id_x]; + + bool const is_last_batch = batch_id == (num_batches - 1); + if (is_last_batch) { + *num_tokens_post_pad = cumsum_val + ceil_b_num_tokens; + } + + if (batch_id < num_batches) { + int32_t const batch_offset = batch_id * max_tokens_per_batch; + for (size_t i = 0; i < b_num_tokens; ++i) { + sorted_ids[cumsum_val + i] = batch_offset + i; + } + + int32_t const block_start = cumsum_val / block_size; + int32_t const num_blocks = ceil_b_num_tokens / block_size; + for (size_t i = 0; i < num_blocks; ++i) { + block_ids[block_start + i] = batch_id; + } + } + } +}; +} // namespace batched_moe_align_block_size + +template +class moe_align_block_size_kernel { + private: + sycl::local_accessor slm; + const scalar_t* __restrict__ topk_ids; + int32_t* __restrict__ sorted_token_ids; + int32_t* __restrict__ expert_ids; + int32_t* __restrict__ total_tokens_post_pad; + int32_t num_experts; + int32_t padded_num_experts; + int32_t experts_per_warp; + int32_t block_size; + size_t numel; + int32_t* __restrict__ cumsum; + int32_t max_num_tokens_padded; + + public: + moe_align_block_size_kernel(sycl::local_accessor& slm, + const scalar_t* __restrict__ topk_ids, + int32_t* __restrict__ sorted_token_ids, + int32_t* __restrict__ expert_ids, + int32_t* __restrict__ total_tokens_post_pad, + int32_t num_experts, int32_t padded_num_experts, + int32_t experts_per_warp, int32_t block_size, + size_t numel, int32_t* __restrict__ cumsum, + int32_t max_num_tokens_padded) + : slm(slm), + topk_ids(topk_ids), + sorted_token_ids(sorted_token_ids), + expert_ids(expert_ids), + total_tokens_post_pad(total_tokens_post_pad), + num_experts(num_experts), + padded_num_experts(padded_num_experts), + experts_per_warp(experts_per_warp), + block_size(block_size), + numel(numel), + cumsum(cumsum), + max_num_tokens_padded(max_num_tokens_padded) {} + + void operator()(sycl::nd_item<1> item) const { + auto group = item.get_group(); + auto local_id_x = item.get_local_id(0); + auto local_range = item.get_local_range(0); + + int32_t* temp_storage = static_cast( + slm.template get_multi_ptr().get()); + + int32_t* shared_counts = temp_storage + 1024; + + // Initialize sorted_token_ids with numel + for (size_t it = local_id_x; it < max_num_tokens_padded; + it += local_range) { + sorted_token_ids[it] = numel; + } + + const int warp_id = local_id_x / WARP_SIZE; + const int my_expert_start = warp_id * experts_per_warp; + + for (int i = 0; i < experts_per_warp; ++i) { + if (my_expert_start + i < padded_num_experts) { + shared_counts[warp_id * experts_per_warp + i] = 0; + } + } + + item.barrier(sycl::access::fence_space::local_space); + + const size_t tid = local_id_x; + const size_t stride = local_range; + + for (size_t i = tid; i < numel; i += stride) { + int expert_id = topk_ids[i]; + 
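+      // topk_ids entries >= num_experts are treated as invalid and do not
+      // contribute to any expert's token count (they are skipped below).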
if (expert_id >= num_experts) { + continue; + } + int warp_idx = expert_id / experts_per_warp; + int expert_offset = expert_id % experts_per_warp; + int idx = warp_idx * experts_per_warp + expert_offset; + sycl::atomic_ref + atomic_count(shared_counts[idx]); + atomic_count.fetch_add(1); + } + + item.barrier(sycl::access::fence_space::local_space); + + // Compute prefix sum over token counts per expert + int expert_count = 0; + int expert_id = local_id_x; + if (expert_id < num_experts) { + int warp_idx = expert_id / experts_per_warp; + int expert_offset = expert_id % experts_per_warp; + expert_count = shared_counts[warp_idx * experts_per_warp + expert_offset]; + expert_count = CEILDIV(expert_count, block_size) * block_size; + } + + temp_storage[local_id_x] = expert_count; + item.barrier(sycl::access::fence_space::local_space); + + int cumsum_val; + sycl::joint_exclusive_scan(item.get_group(), temp_storage, + temp_storage + 1024, temp_storage, 0, + sycl::plus{}); + cumsum_val = temp_storage[local_id_x]; + if (expert_id <= num_experts) { + cumsum[expert_id] = cumsum_val; + } + + if (expert_id == num_experts) { + *total_tokens_post_pad = cumsum_val; + } + + item.barrier(sycl::access::fence_space::local_space); + + if (local_id_x < num_experts) { + for (int i = cumsum[local_id_x]; i < cumsum[local_id_x + 1]; + i += block_size) { + expert_ids[i / block_size] = local_id_x; + } + } + + // Fill remaining expert_ids with 0 + const size_t fill_start_idx = cumsum[num_experts] / block_size + local_id_x; + const size_t expert_ids_size = CEILDIV(max_num_tokens_padded, block_size); + for (size_t i = fill_start_idx; i < expert_ids_size; i += local_range) { + expert_ids[i] = 0; + } + } +}; + +template +class count_and_sort_expert_tokens_kernel { + private: + const scalar_t* __restrict__ topk_ids; + int32_t* __restrict__ sorted_token_ids; + int32_t* __restrict__ cumsum_buffer; + size_t numel; + int32_t num_experts; + + public: + count_and_sort_expert_tokens_kernel(const scalar_t* __restrict__ topk_ids, + int32_t* __restrict__ sorted_token_ids, + int32_t* __restrict__ cumsum_buffer, + size_t numel, int32_t num_experts) + : topk_ids(topk_ids), + sorted_token_ids(sorted_token_ids), + cumsum_buffer(cumsum_buffer), + numel(numel), + num_experts(num_experts) {} + + void operator()(sycl::nd_item<1> item) const { + const size_t tid = item.get_global_linear_id(); + const size_t stride = item.get_global_range(0); + + for (size_t i = tid; i < numel; i += stride) { + int32_t expert_id = topk_ids[i]; + if (expert_id >= num_experts) { + continue; + } + + auto atomic_count = + sycl::atomic_ref( + *(cumsum_buffer + expert_id)); + int32_t rank_post_pad = atomic_count.fetch_add(1); + + sorted_token_ids[rank_post_pad] = i; + } + } +}; + template class moe_sum_kernel { private: @@ -34,9 +316,248 @@ class moe_sum_kernel { } }; +template +class moe_align_block_size_small_batch_expert_kernel { + private: + sycl::local_accessor slm; + const scalar_t* __restrict__ topk_ids; + int32_t* __restrict__ sorted_token_ids; + int32_t* __restrict__ expert_ids; + int32_t* __restrict__ total_tokens_post_pad; + int32_t num_experts; + int32_t block_size; + size_t numel; + int32_t max_num_tokens_padded; + + public: + moe_align_block_size_small_batch_expert_kernel( + sycl::local_accessor& slm, + const scalar_t* __restrict__ topk_ids, + int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids, + int32_t* __restrict__ total_tokens_post_pad, int32_t num_experts, + int32_t block_size, size_t numel, int32_t max_num_tokens_padded) + : 
slm(slm), + topk_ids(topk_ids), + sorted_token_ids(sorted_token_ids), + expert_ids(expert_ids), + total_tokens_post_pad(total_tokens_post_pad), + num_experts(num_experts), + block_size(block_size), + numel(numel), + max_num_tokens_padded(max_num_tokens_padded) {} + void operator()(sycl::nd_item<1> item) const { + auto group = item.get_group(); + auto local_id_x = item.get_local_id(0); + auto local_range = item.get_local_range(0); + + // Initialize sorted_token_ids with numel + for (size_t it = local_id_x; it < max_num_tokens_padded; + it += local_range) { + sorted_token_ids[it] = numel; + } + + const size_t tid = local_id_x; + const size_t stride = local_range; + + void* slm_ptr = static_cast( + slm.template get_multi_ptr().get()); + int32_t* cumsum = reinterpret_cast(slm_ptr); + int32_t* tokens_cnts = cumsum + num_experts + 1; + + for (int i = 0; i < num_experts; ++i) { + tokens_cnts[(local_id_x + 1) * num_experts + i] = 0; + } + + for (size_t i = tid; i < numel; i += stride) { + ++tokens_cnts[(local_id_x + 1) * num_experts + topk_ids[i]]; + } + + item.barrier(sycl::access::fence_space::local_space); + + if (local_id_x < num_experts) { + tokens_cnts[local_id_x] = 0; + for (int i = 1; i <= local_range; ++i) { + tokens_cnts[i * num_experts + local_id_x] += + tokens_cnts[(i - 1) * num_experts + local_id_x]; + } + } + + item.barrier(sycl::access::fence_space::local_space); + + if (local_id_x == 0) { + cumsum[0] = 0; + for (int i = 1; i <= num_experts; ++i) { + cumsum[i] = cumsum[i - 1] + + CEILDIV(tokens_cnts[local_range * num_experts + i - 1], + block_size) * + block_size; + } + *total_tokens_post_pad = static_cast(cumsum[num_experts]); + } + + item.barrier(sycl::access::fence_space::local_space); + + if (local_id_x < num_experts) { + for (int i = cumsum[local_id_x]; i < cumsum[local_id_x + 1]; + i += block_size) { + expert_ids[i / block_size] = local_id_x; + } + } + + // Fill remaining expert_ids with 0 + const size_t fill_start_idx = cumsum[num_experts] / block_size + local_id_x; + const size_t expert_ids_size = CEILDIV(max_num_tokens_padded, block_size); + for (size_t i = fill_start_idx; i < expert_ids_size; i += local_range) { + expert_ids[i] = 0; + } + + for (size_t i = tid; i < numel; i += stride) { + int32_t expert_id = topk_ids[i]; + int32_t rank_post_pad = + tokens_cnts[local_id_x * num_experts + expert_id] + cumsum[expert_id]; + sorted_token_ids[rank_post_pad] = i; + ++tokens_cnts[local_id_x * num_experts + expert_id]; + } + } +}; + } // namespace moe } // namespace vllm +// taken from +// https://github.com/sgl-project/sglang/blob/8b5f83ed3b7d2a49ad5c5cd5aa61c5d502f47dbc +void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, + int64_t block_size, torch::Tensor sorted_token_ids, + torch::Tensor experts_ids, + torch::Tensor num_tokens_post_pad) { + const auto& queue = at::xpu::getCurrentXPUStream(); + + constexpr int32_t WARP_SIZE = 32; + int64_t padded_num_experts = + ((num_experts + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; + int experts_per_warp = WARP_SIZE; + int threads = 1024; + threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; + + // BlockScan uses 1024 threads and assigns one thread per expert. 
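+  // The scan runs in a single 1024-thread work-group with one thread per
+  // (padded) expert, so the padded expert count must fit in that group.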
+ TORCH_CHECK(padded_num_experts < 1024, + "padded_num_experts must be less than 1024"); + + VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES( + topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { + // calc needed amount of shared mem for `cumsum` tensors + auto options_int = + torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device()); + torch::Tensor cumsum_buffer = + torch::empty({num_experts + 1}, options_int); + bool small_batch_expert_mode = + (topk_ids.numel() < 1024) && (num_experts <= 64); + + if (small_batch_expert_mode) { + const int32_t threads = (int32_t)num_experts > WARP_SIZE + ? (int32_t)num_experts + : WARP_SIZE; + const int32_t shared_mem_size = + ((threads + 1) * num_experts + (num_experts + 1)) * + sizeof(int32_t); + + sycl::range<1> grid1(1); + sycl::range<1> block1(threads); + using small_batch_expert_kernel = + vllm::moe::moe_align_block_size_small_batch_expert_kernel< + scalar_t>; + (*queue).submit([&](sycl::handler& cgh) { + sycl::local_accessor slm( + sycl::range<1>(shared_mem_size / sizeof(int32_t)), cgh); + cgh.parallel_for( + sycl::nd_range<1>(grid1 * block1, block1), + small_batch_expert_kernel( + slm, topk_ids.data_ptr(), + sorted_token_ids.data_ptr(), + experts_ids.data_ptr(), + num_tokens_post_pad.data_ptr(), num_experts, + block_size, topk_ids.numel(), sorted_token_ids.size(0))); + }); + } else { + sycl::range<1> grid1(1); + sycl::range<1> block1(threads); + using align_kernel = vllm::moe::moe_align_block_size_kernel; + + size_t num_warps = CEILDIV(padded_num_experts, experts_per_warp); + size_t shared_mem_num = 1024 + num_warps * experts_per_warp; + + (*queue).submit([&](sycl::handler& cgh) { + sycl::local_accessor slm(sycl::range<1>(shared_mem_num), + cgh); + cgh.parallel_for( + sycl::nd_range<1>(grid1 * block1, block1), + align_kernel(slm, topk_ids.data_ptr(), + sorted_token_ids.data_ptr(), + experts_ids.data_ptr(), + num_tokens_post_pad.data_ptr(), + num_experts, padded_num_experts, experts_per_warp, + block_size, topk_ids.numel(), + cumsum_buffer.data_ptr(), + sorted_token_ids.size(0))); + }); + + const int block_threads = std::min(256, (int)threads); + const int num_blocks = + (topk_ids.numel() + block_threads - 1) / block_threads; + const int max_blocks = 65535; + const int actual_blocks = std::min(num_blocks, max_blocks); + + sycl::range<1> grid2(actual_blocks); + sycl::range<1> block2(block_threads); + using sort_kernel = + vllm::moe::count_and_sort_expert_tokens_kernel; + + (*queue).submit([&](sycl::handler& cgh) { + cgh.parallel_for(sycl::nd_range<1>(grid2 * block2, block2), + sort_kernel(topk_ids.data_ptr(), + sorted_token_ids.data_ptr(), + cumsum_buffer.data_ptr(), + topk_ids.numel(), num_experts)); + }); + } + }); +} + +void batched_moe_align_block_size(int64_t max_tokens_per_batch, + int64_t block_size, + torch::Tensor const& batch_num_tokens, + torch::Tensor sorted_ids, + torch::Tensor batch_ids, + torch::Tensor num_tokens_post_pad) { + namespace batched_kernel = vllm::moe::batched_moe_align_block_size; + + const auto& queue = at::xpu::getCurrentXPUStream(); + int32_t const B = batch_num_tokens.size(0); + int32_t const num_blocks_per_batch = + round_to_next_multiple_of(max_tokens_per_batch, block_size) / block_size; + int32_t const num_blocks = num_blocks_per_batch * B; + int64_t const sorted_ids_size = num_blocks * block_size; + + TORCH_CHECK(sorted_ids.size(0) == sorted_ids_size); + TORCH_CHECK(batch_ids.size(0) == sorted_ids_size / block_size); + TORCH_CHECK(num_tokens_post_pad.size(0) == 1); + TORCH_CHECK(B <= 
batched_kernel::num_threads); + + sycl::range<1> grid(batched_kernel::num_blocks); + sycl::range<1> block(batched_kernel::num_threads); + + (*queue).submit([&](sycl::handler& cgh) { + sycl::local_accessor slm(sycl::range<1>(1024), cgh); + cgh.parallel_for( + sycl::nd_range<1>(grid * block, block), + batched_kernel::batched_moe_align_block_size_kernel( + slm, B, max_tokens_per_batch, block_size, + batch_num_tokens.data_ptr(), + sorted_ids.data_ptr(), batch_ids.data_ptr(), + num_tokens_post_pad.data_ptr())); + }); +} + void moe_sum(torch::Tensor& input, // [num_tokens, topk, hidden_size] torch::Tensor& output) // [num_tokens, hidden_size] { diff --git a/csrc/moe/moe_ops.h b/csrc/moe/moe_ops.h index 4914ad4..87c26d9 100644 --- a/csrc/moe/moe_ops.h +++ b/csrc/moe/moe_ops.h @@ -4,6 +4,18 @@ void moe_sum(torch::Tensor& input, torch::Tensor& output); +void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, + int64_t block_size, torch::Tensor sorted_token_ids, + torch::Tensor experts_ids, + torch::Tensor num_tokens_post_pad); + +void batched_moe_align_block_size(int64_t max_tokens_per_batch, + int64_t block_size, + torch::Tensor const& expert_num_tokens, + torch::Tensor sorted_ids, + torch::Tensor expert_ids, + torch::Tensor num_tokens_post_pad); + std::tuple grouped_topk( torch::Tensor const& scores, torch::Tensor const& scores_with_bias, int64_t n_group, int64_t topk_group, int64_t topk, bool renormalize, diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index ea9fbc6..3e6ef33 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -7,6 +7,26 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { m.def("moe_sum(Tensor input, Tensor! output) -> ()"); m.impl("moe_sum", torch::kXPU, &moe_sum); + // Aligning the number of tokens to be processed by each expert such + // that it is divisible by the block size. + m.def( + "moe_align_block_size(Tensor topk_ids, int num_experts," + " int block_size, Tensor! sorted_token_ids," + " Tensor! experts_ids," + " Tensor! num_tokens_post_pad) -> ()"); + m.impl("moe_align_block_size", torch::kXPU, &moe_align_block_size); + + // Aligning the number of tokens to be processed by each expert such + // that it is divisible by the block size, but for the batched case. + m.def( + "batched_moe_align_block_size(int max_tokens_per_batch," + " int block_size, Tensor expert_num_tokens," + " Tensor! sorted_token_ids," + " Tensor! experts_ids," + " Tensor! num_tokens_post_pad) -> ()"); + m.impl("batched_moe_align_block_size", torch::kXPU, + &batched_moe_align_block_size); + // Apply grouped topk routing to select experts. m.def( "grouped_topk(Tensor scores, Tensor scores_with_bias, int n_group, int " diff --git a/tests/ops/moe_align_block_size_ops.py b/tests/ops/moe_align_block_size_ops.py new file mode 100644 index 0000000..b10a172 --- /dev/null +++ b/tests/ops/moe_align_block_size_ops.py @@ -0,0 +1,178 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch + +import tests.register_ops as ops +from tests.utils import round_up + + +def moe_align_block_size( + topk_ids: torch.Tensor, + block_size: int, + num_experts: int, + expert_map: torch.Tensor | None = None, + pad_sorted_ids: bool = False, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Aligns the token distribution across experts to be compatible with block + size for matrix multiplication. 
+ + Note: In the case of expert_parallel, moe_align_block_size initially + considers all experts as valid and aligns all tokens appropriately. + Before the function returns it marks the experts_ids that are not in + the current GPU rank as -1 so the MoE matmuls could skip those blocks. + This requires the num_experts input arg to be the num global experts. + + Parameters: + - topk_ids: A tensor of shape [total_tokens, top_k] representing the + top-k expert indices for each token. + - block_size: The block size used in block matrix multiplication. + - num_experts: The total number of experts. + - expert_map: A tensor of shape [num_experts] that maps the expert index + from the global space to the local index space of the current + expert parallel shard. If the expert is not in the current expert + parallel shard, the mapping is set to -1. + - pad_sorted_ids: A flag indicating whether the sorted_token_ids length + should be padded to a multiple of block_size, + + Returns: + - sorted_token_ids: A tensor containing the sorted token indices according + to their allocated expert. + - expert_ids: A tensor indicating the assigned expert index for each block. + - num_tokens_post_padded: The total number of tokens after padding, + ensuring divisibility by block_size. + + This function pads the number of tokens that each expert needs to process + so that it is divisible by block_size. + Padding ensures that during block matrix multiplication, the dimensions + align correctly. + + Example: + Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]], + block_size = 4, and num_experts = 4: + - We initially have 12 tokens (after repeating 'top_k' times) and 4 experts, + with each expert needing to process 3 tokens. + - As block_size is 4, we pad 1 token for each expert. + - First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3]. + - Then append padding tokens [12, 12, 12, 12] for each block. + - After sorting by expert index, we obtain token_ids + [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12]. + Tokens 12 are non-existent (padding) and are ignored in + the subsequent matrix multiplication. + - The padding ensures that the total number of tokens is now divisible + by block_size for proper block matrix operations. + """ + max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) + if pad_sorted_ids: + max_num_tokens_padded = round_up(max_num_tokens_padded, block_size) + sorted_ids = torch.empty((max_num_tokens_padded, ), + dtype=torch.int32, + device=topk_ids.device) + max_num_m_blocks = (max_num_tokens_padded + block_size - 1) // block_size + expert_ids = torch.empty((max_num_m_blocks, ), + dtype=torch.int32, + device=topk_ids.device) + num_tokens_post_pad = torch.empty((1), + dtype=torch.int32, + device=topk_ids.device) + + ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids, + expert_ids, num_tokens_post_pad) + if expert_map is not None: + expert_ids = expert_map[expert_ids] + + return sorted_ids, expert_ids, num_tokens_post_pad + + +def batched_moe_align_block_size( + max_tokens_per_batch: int, block_size: int, expert_num_tokens: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Given num_batches, max_tokens_per_batch, block_size and the number of + valid-tokens in each batch, prepare sorted_token_ids, expert_ids and + num_tokens_post_pad. sorted_token_ids, expert_ids and num_tokens_post_pad + have the same semantics as in moe_align_block_size. 
+
+    This function is intended to be a drop-in replacement for
+    moe_align_block_size for the batched case.
+
+    Parameters:
+    - max_tokens_per_batch (int): Number of tokens in each batch (both
+        valid and invalid).
+    - block_size (int): block_size to align the data to.
+    - expert_num_tokens (torch.Tensor): expert_num_tokens[i] indicates
+        the number of valid tokens in batch i.
+
+    Returns:
+    - sorted_token_ids (torch.Tensor): Torch tensor of size
+        (num_batches * round_up(max_tokens_per_batch, block_size)) holding
+        the token indices, grouped by batch and padded per batch to a
+        multiple of block_size.
+    - expert_ids (torch.Tensor): Torch tensor of size
+        (num_batches * round_up(max_tokens_per_batch, block_size)) //
+        block_size indicating which expert to use for each block.
+    - num_tokens_post_pad (torch.Tensor): Torch tensor of size 1
+        indicating the number of valid entries in sorted_token_ids after
+        per-batch padding, expressed in tokens.
+    Example:
+    Let num_batches=5, max_tokens_per_batch=8, block_size=4, and
+    expert_num_tokens=[2, 3, 0, 6, 8]. This expert_num_tokens tensor
+    indicates that:
+     - The first 2 tokens in the 0th batch are valid and the remaining 6
+       are invalid (i.e. in the 2D hidden_states tensor of shape
+       [num_batches * max_tokens_per_batch, K], indices 0, 1 are valid).
+     - The first 3 tokens in the 1st batch are valid, i.e. indices 8, 9, 10.
+     - 0 tokens in the 2nd batch are valid.
+     - The first 6 tokens in the 3rd batch are valid, i.e. indices
+       24, 25, 26, 27, 28, 29.
+     - And so on.
+
+    In this case,
+    sorted_token_ids will be [0, 1, 40, 40,
+                              8, 9, 10, 40,
+                              24, 25, 26, 27,
+                              28, 29, 40, 40,
+                              32, 33, 34, 35,
+                              36, 37, 38, 39,
+                              40, 40, 40, 40,
+                              (rest all 40, 40, 40, 40)
+                              ...]
+    Here, 40 represents an invalid index, as there is no token index 40
+    (num_batches * max_tokens_per_batch = 40). The gemm kernel using this
+    sorted_token_ids is expected to skip the gemm computation when it
+    encounters this invalid index.
+
+    expert_ids will be [0, 1, 3, 3, 4, 4, -1, -1, (rest all -1) ...]
+    Here, -1 represents an invalid expert. The gemm kernel using this
+    expert_ids is expected to skip the gemm computation when it encounters
+    an expert of id -1.
+
+    num_tokens_post_pad will be 24, since the first 24 entries of
+    sorted_token_ids cover all valid tokens after per-batch padding.
+    """
+
+    B = expert_num_tokens.size(0)
+    device = expert_num_tokens.device
+
+    # Round up so each batch can be split to blocks evenly.
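+    # e.g. with the docstring example (B=5, max_tokens_per_batch=8,
+    # block_size=4) this gives 5 * 8 = 40 sorted_ids slots and
+    # 40 // 4 = 10 expert_ids entries.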
+ max_num_tokens_padded = B * round_up(max_tokens_per_batch, block_size) + + sorted_ids = torch.empty((max_num_tokens_padded, ), + dtype=torch.int32, + device=device) + assert max_num_tokens_padded % block_size == 0 + max_num_m_blocks = max_num_tokens_padded // block_size + expert_ids = torch.empty((max_num_m_blocks, ), + dtype=torch.int32, + device=device) + num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=device) + + ops.batched_moe_align_block_size( + max_tokens_per_batch, + block_size, + expert_num_tokens, + sorted_ids, + expert_ids, + num_tokens_post_pad, + ) + + return sorted_ids, expert_ids, num_tokens_post_pad diff --git a/tests/register_ops.py b/tests/register_ops.py index 43f8738..da09e60 100644 --- a/tests/register_ops.py +++ b/tests/register_ops.py @@ -195,6 +195,42 @@ def moe_sum(input: torch.Tensor, output: torch.Tensor) -> None: torch.ops._moe_C.moe_sum(input, output) +def moe_align_block_size( + topk_ids: torch.Tensor, + num_experts: int, + block_size: int, + sorted_token_ids: torch.Tensor, + experts_ids: torch.Tensor, + num_tokens_post_pad: torch.Tensor, +) -> None: + torch.ops._moe_C.moe_align_block_size( + topk_ids, + num_experts, + block_size, + sorted_token_ids, + experts_ids, + num_tokens_post_pad, + ) + + +def batched_moe_align_block_size( + max_tokens_per_batch: int, + block_size: int, + expert_num_tokens: torch.Tensor, + sorted_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_pad: torch.Tensor, +) -> None: + torch.ops._moe_C.batched_moe_align_block_size( + max_tokens_per_batch, + block_size, + expert_num_tokens, + sorted_ids, + expert_ids, + num_tokens_post_pad, + ) + + def grouped_topk(scores: torch.Tensor, scores_with_bias: torch.Tensor, num_expert_group: int, topk_group: int, topk: int, renormalize: bool, routed_scaling_factor: float): diff --git a/tests/test_moe_align_block_size.py b/tests/test_moe_align_block_size.py new file mode 100644 index 0000000..06662ec --- /dev/null +++ b/tests/test_moe_align_block_size.py @@ -0,0 +1,474 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for the MOE align block size function. + +Run `pytest tests/test_moe_align_block_size.py`. 
+"""
+
+import pytest
+import torch
+
+from tests.ops.moe_align_block_size_ops import (batched_moe_align_block_size,
+                                                moe_align_block_size)
+from tests.utils import opcheck, round_up, seed_everything
+
+NUM_TOKENS = [1, 3, 256, 2256, 4096]
+NUM_EXPERTS = [32, 160, 256, 257]
+TOP_KS = [1, 2, 16, 32]
+BLOCK_SIZES = [32, 128]
+seed_everything(0)
+
+
+def _group_tokens_by_expert(
+    sorted_ids: torch.Tensor,
+    expert_ids: torch.Tensor,
+    block_size: int,
+    valid_length: int,
+    total_tokens: int,
+) -> dict:
+    num_blocks = valid_length // block_size
+    expert_tokens: dict[int, list[int]] = {}
+
+    for block_idx in range(num_blocks):
+        expert_id = expert_ids[block_idx].item()
+        block_start = block_idx * block_size
+        block_end = min(block_start + block_size, valid_length)
+
+        block_tokens = sorted_ids[block_start:block_end]
+        valid_tokens = block_tokens[block_tokens < total_tokens]
+
+        if expert_id not in expert_tokens:
+            expert_tokens[expert_id] = []
+        expert_tokens[expert_id].extend(valid_tokens.tolist())
+    return expert_tokens
+
+
+def _verify_expert_level_sorting(
+    actual_sorted_ids: torch.Tensor,
+    golden_sorted_ids: torch.Tensor,
+    expert_ids: torch.Tensor,
+    block_size: int,
+    valid_length: int,
+    total_tokens: int,
+):
+    """
+    Verify that actual_sorted_ids follows the correct expert-level sorting.
+    The kernel implementation may or may not preserve the original token
+    order of topk_ids in the final sorted_ids; however, this does not
+    impact quality.
+    """
+    # Group tokens by expert from the golden implementation
+    golden_expert_tokens = _group_tokens_by_expert(golden_sorted_ids,
+                                                   expert_ids, block_size,
+                                                   valid_length, total_tokens)
+
+    actual_expert_tokens = _group_tokens_by_expert(actual_sorted_ids,
+                                                   expert_ids, block_size,
+                                                   valid_length, total_tokens)
+
+    assert set(golden_expert_tokens.keys()) == set(
+        actual_expert_tokens.keys()), (
+            f"Expert IDs mismatch: golden={set(golden_expert_tokens.keys())}, "
+            f"actual={set(actual_expert_tokens.keys())}")
+
+    for expert_id in golden_expert_tokens:
+        golden_tokens = torch.tensor(golden_expert_tokens[expert_id],
+                                     device=actual_sorted_ids.device)
+        actual_tokens = torch.tensor(actual_expert_tokens[expert_id],
+                                     device=actual_sorted_ids.device)
+        assert torch.equal(
+            torch.sort(golden_tokens)[0],
+            torch.sort(actual_tokens)[0]), (
+                f"Expert {expert_id} token mismatch: "
+                f"golden={golden_expert_tokens[expert_id]}, "
+                f"actual={actual_expert_tokens[expert_id]}")
+
+
+def torch_moe_align_block_size(
+    topk_ids: torch.Tensor,
+    block_size: int,
+    num_experts: int,
+    expert_map: torch.Tensor | None = None,
+    pad_sorted_ids: bool = False,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """
+    Golden torch implementation of moe_align_block_size.
+
+    This function aligns the token distribution across experts to be compatible
+    with block size for matrix multiplication by sorting tokens by expert and
+    padding to block boundaries.
+ """ + max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) + if pad_sorted_ids: + max_num_tokens_padded = round_up(max_num_tokens_padded, block_size) + + flattened_token_indices = torch.arange(topk_ids.numel(), + device=topk_ids.device, + dtype=torch.int32) + flattened_expert_ids = topk_ids.flatten() + sorted_expert_ids, sort_indices = torch.sort(flattened_expert_ids, + stable=True) + sorted_token_indices = flattened_token_indices[sort_indices] + + expert_token_counts = torch.zeros(num_experts, + dtype=torch.int64, + device=topk_ids.device) + for expert_id in range(num_experts): + mask = sorted_expert_ids == expert_id + expert_token_counts[expert_id] = mask.sum() + + expert_padded_counts = torch.zeros(num_experts, + dtype=torch.int64, + device=topk_ids.device) + for expert_id in range(num_experts): + original_count = expert_token_counts[expert_id] + if original_count > 0: + expert_padded_counts[expert_id] = ( + (original_count + block_size - 1) // block_size) * block_size + + sorted_token_ids = torch.full( + (max_num_tokens_padded, ), + topk_ids.numel(), + dtype=torch.int32, + device=topk_ids.device, + ) + max_num_blocks = (max_num_tokens_padded + block_size - 1) // block_size + expert_ids = torch.zeros(max_num_blocks, + dtype=torch.int32, + device=topk_ids.device) + + current_pos = 0 + current_block = 0 + for expert_id in range(num_experts): + expert_mask = sorted_expert_ids == expert_id + expert_tokens = sorted_token_indices[expert_mask] + num_expert_tokens = expert_tokens.shape[0] + + if num_expert_tokens > 0: + sorted_token_ids[current_pos:current_pos + + num_expert_tokens] = (expert_tokens) + + expert_blocks_needed = expert_padded_counts[expert_id] // block_size + expert_ids[current_block:current_block + + expert_blocks_needed] = expert_id + + current_pos += expert_padded_counts[expert_id] + current_block += expert_blocks_needed + + total_padded_tokens = expert_padded_counts.sum() + num_tokens_post_pad = torch.tensor([total_padded_tokens], + dtype=torch.int32, + device=topk_ids.device) + + if expert_map is not None: + expert_ids = expert_map[expert_ids] + return sorted_token_ids, expert_ids, num_tokens_post_pad + + +@pytest.mark.parametrize("m", NUM_TOKENS) +@pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("num_experts", NUM_EXPERTS) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("pad_sorted_ids", [False, True]) +def test_moe_align_block_size(m: int, topk: int, num_experts: int, + block_size: int, pad_sorted_ids: bool): + """Test moe_align_block_size without expert mapping""" + topk_ids = torch.zeros((m, topk), device="xpu", dtype=torch.int32) + for i in range(m): + experts = torch.randperm(num_experts, device="xpu")[:topk] + topk_ids[i] = experts + + actual_sorted_ids, actual_expert_ids, actual_num_tokens =\ + moe_align_block_size( + topk_ids=topk_ids, + block_size=block_size, + num_experts=num_experts, + pad_sorted_ids=pad_sorted_ids, + ) + golden_sorted_ids, golden_expert_ids, golden_num_tokens = ( + torch_moe_align_block_size( + topk_ids=topk_ids, + block_size=block_size, + num_experts=num_experts, + pad_sorted_ids=pad_sorted_ids, + )) + + torch.testing.assert_close(actual_num_tokens, + golden_num_tokens, + atol=0, + rtol=0) + torch.testing.assert_close(actual_expert_ids, + golden_expert_ids, + atol=0, + rtol=0) + + # For sorted_token_ids, verify block-level correctness rather than exact + # order Tokens within each expert's blocks can be in any order, but expert + # regions must be correct + 
_verify_expert_level_sorting( + actual_sorted_ids, + golden_sorted_ids, + actual_expert_ids, + block_size, + actual_num_tokens.item(), + m * topk, + ) + + total_tokens = m * topk + assert actual_num_tokens.item() % block_size == 0, ( + "num_tokens_post_pad should be divisible by block_size") + assert actual_num_tokens.item() >= total_tokens, ( + "num_tokens_post_pad should be at least total_tokens") + valid_tokens = actual_sorted_ids[actual_sorted_ids < total_tokens] + assert len(valid_tokens) == total_tokens, ( + f"Should have exactly {total_tokens} valid tokens," + f" got {len(valid_tokens)}") + assert (actual_expert_ids + >= 0).all() and (actual_expert_ids < num_experts).all(), ( + "expert_ids should contain valid expert indices") + + +@pytest.mark.parametrize("m", [16, 32]) +@pytest.mark.parametrize("topk", [2, 4]) +@pytest.mark.parametrize("num_experts", [8]) +@pytest.mark.parametrize("block_size", [64]) +def test_moe_align_block_size_with_expert_map(m: int, topk: int, + num_experts: int, + block_size: int): + """Test moe_align_block_size with expert mapping (EP scenario)""" + topk_ids = torch.zeros((m, topk), device="xpu", dtype=torch.int32) + for i in range(m): + experts = torch.randperm(num_experts, device="xpu")[:topk] + topk_ids[i] = experts + + expert_map = torch.full((num_experts, ), + -1, + device="xpu", + dtype=torch.int32) + local_experts = list(range(0, num_experts, 2)) + for i, expert_id in enumerate(local_experts): + expert_map[expert_id] = i + + actual_sorted_ids, actual_expert_ids, actual_num_tokens = \ + moe_align_block_size( + topk_ids=topk_ids, + block_size=block_size, + num_experts=num_experts, + expert_map=expert_map, + ) + golden_sorted_ids, golden_expert_ids, golden_num_tokens = ( + torch_moe_align_block_size( + topk_ids=topk_ids, + block_size=block_size, + num_experts=num_experts, + expert_map=expert_map, + )) + + torch.testing.assert_close(actual_num_tokens, + golden_num_tokens, + atol=0, + rtol=0) + torch.testing.assert_close(actual_expert_ids, + golden_expert_ids, + atol=0, + rtol=0) + _verify_expert_level_sorting( + actual_sorted_ids, + golden_sorted_ids, + actual_expert_ids, + block_size, + actual_num_tokens.item(), + m * topk, + ) + + +def test_moe_align_block_size_deterministic(): + m, topk, num_experts, block_size = 128, 2, 32, 64 + + torch.manual_seed(42) + topk_ids = torch.randint(0, + num_experts, (m, topk), + device="xpu", + dtype=torch.int32) + + # expect the results to be reproducible + results = [] + for _ in range(5): + sorted_ids, expert_ids, num_tokens = moe_align_block_size( + topk_ids=topk_ids, block_size=block_size, num_experts=num_experts) + results.append( + (sorted_ids.clone(), expert_ids.clone(), num_tokens.clone())) + + for i in range(1, len(results)): + assert torch.equal( + results[0][0], + results[i][0]), ("sorted_ids should be deterministic") + assert torch.equal( + results[0][1], + results[i][1]), ("expert_ids should be deterministic") + assert torch.equal( + results[0][2], + results[i][2]), ("num_tokens should be deterministic") + + +@pytest.mark.parametrize("max_tokens_per_batch", [13, 16, 512]) +@pytest.mark.parametrize("num_experts", [8, 16, 32, 64]) +@pytest.mark.parametrize("block_size", [8, 16, 32, 64]) +@pytest.mark.parametrize("simulate_empty_batches", [False, True]) +def test_batched_moe_align_block_size( + max_tokens_per_batch: int, + num_experts: int, + block_size: int, + simulate_empty_batches: bool, +): + + def ref_outputs( + expert_num_tokens: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + E 
= expert_num_tokens.size(0) + + # Round up so each batch can be split to blocks evenly. + Msum = round_up(max_tokens_per_batch, block_size) * E + ref_sorted_ids = torch.empty((Msum, ), dtype=torch.int32) + ref_expert_ids = torch.empty((Msum // block_size, ), dtype=torch.int32) + ref_num_tokens_post_pad = torch.empty((1, ), dtype=torch.int32) + + # Initialize + sentinel = E * max_tokens_per_batch + ref_sorted_ids.fill_(sentinel) + ref_expert_ids.fill_(-1) + + # Fill ref_sorted_ids + i = 0 + for expert_id, expert_nt in enumerate(expert_num_tokens): + token_offset = expert_id * max_tokens_per_batch + for j in range(expert_nt): + ref_sorted_ids[i] = token_offset + j + i += 1 + # round up i to the next block_size + i = round_up(i, block_size) + + ref_num_tokens_post_pad[0] = i + + # Fill expert_ids + nt_ceil_sum = 0 + for expert_id, expert_nt in enumerate(expert_num_tokens): + expert_ids_offset = nt_ceil_sum // block_size + ceil_expert_nt = round_up(int(expert_nt.item()), block_size) + num_blocks = ceil_expert_nt // block_size + for x in range(num_blocks): + ref_expert_ids[expert_ids_offset + x] = expert_id + nt_ceil_sum += ceil_expert_nt + + return ( + ref_sorted_ids.to("xpu"), + ref_expert_ids.to("xpu"), + ref_num_tokens_post_pad.to("xpu"), + ) + + # Compute expert_num_tokens + expert_num_tokens = torch.randint( + low=0, + high=max_tokens_per_batch, + size=(num_experts, ), + device="cpu", + dtype=torch.int32, + ) + if simulate_empty_batches: + # mark half the batches to have 0 tokens + zero_batches = torch.randperm(num_experts)[:num_experts // 2] + expert_num_tokens[zero_batches] = 0 + + # ref outputs + ref_sorted_ids, ref_expert_ids, ref_num_tokens_post_pad = ref_outputs( + expert_num_tokens) + + # outputs + sorted_ids, expert_ids, num_tokens_post_pad = batched_moe_align_block_size( + max_tokens_per_batch, block_size, expert_num_tokens.to("xpu")) + + assert ref_sorted_ids.size() == sorted_ids.size(), ( + f"{ref_sorted_ids.size()} vs {sorted_ids.size()}") + assert ref_expert_ids.size() == expert_ids.size(), ( + f"{ref_expert_ids.size()} vs {expert_ids.size()}") + assert ref_num_tokens_post_pad.size() == num_tokens_post_pad.size(), ( + f"{ref_num_tokens_post_pad.size()} vs {num_tokens_post_pad.size()}") + + torch.testing.assert_close(ref_sorted_ids, sorted_ids, atol=0, rtol=0) + torch.testing.assert_close(ref_expert_ids, expert_ids, atol=0, rtol=0) + torch.testing.assert_close(ref_num_tokens_post_pad, + num_tokens_post_pad, + atol=0, + rtol=0) + + +def test_moe_align_block_size_opcheck(): + num_experts = 4 + block_size = 4 + topk_ids = torch.randint(0, + num_experts, (3, 4), + dtype=torch.int32, + device="xpu") + + max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) + sorted_ids = torch.empty((max_num_tokens_padded, ), + dtype=torch.int32, + device=topk_ids.device) + sorted_ids.fill_(topk_ids.numel()) + max_num_m_blocks = max_num_tokens_padded // block_size + expert_ids = torch.empty((max_num_m_blocks, ), + dtype=torch.int32, + device=topk_ids.device) + num_tokens_post_pad = torch.empty((1), + dtype=torch.int32, + device=topk_ids.device) + + opcheck( + torch.ops._moe_C.moe_align_block_size, + ( + topk_ids, + num_experts, + block_size, + sorted_ids, + expert_ids, + num_tokens_post_pad, + ), + ) + + +def test_batched_moe_align_block_size_opcheck(): + max_tokens_per_batch = 512 + num_experts = 4 + block_size = 16 + + expert_num_tokens = torch.randint( + low=0, + high=max_tokens_per_batch, + size=(num_experts, ), + dtype=torch.int32, + device="xpu", + ) + + 
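+    # The kernel checks that sorted_ids holds
+    # num_experts * round_up(max_tokens_per_batch, block_size) entries; for
+    # max_tokens_per_batch=512 and block_size=16 the max() below is the same.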
max_num_tokens_padded = num_experts * max(max_tokens_per_batch, block_size) + sorted_ids = torch.empty((max_num_tokens_padded, ), + dtype=torch.int32, + device="xpu") + + assert max_num_tokens_padded % block_size == 0 + max_num_m_blocks = max_num_tokens_padded // block_size + expert_ids = torch.empty((max_num_m_blocks, ), + dtype=torch.int32, + device="xpu") + + num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device="xpu") + + opcheck( + torch.ops._moe_C.batched_moe_align_block_size, + ( + max_tokens_per_batch, + block_size, + expert_num_tokens, + sorted_ids, + expert_ids, + num_tokens_post_pad, + ), + ) diff --git a/tests/utils.py b/tests/utils.py index 0559bac..fe4d679 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -350,4 +350,8 @@ def check_ipex_availability(): return True else: print("Warning: IPEX not available, skipping IPEX benchmarks") - return False \ No newline at end of file + return False + + +def round_up(x: int, y: int) -> int: + return ((x + y - 1) // y) * y