add op moe_align_block_size & batched_moe_align_block_size #54
base: main
Conversation
Signed-off-by: mayuyuace <[email protected]>
Signed-off-by: mayuyuace <[email protected]>
Pull Request Overview
This PR adds two new MoE (Mixture of Experts) operations, moe_align_block_size and batched_moe_align_block_size, which pad and sort the token-to-expert assignment so that each expert's token count becomes a multiple of the block size used by the block-wise matrix multiplication. The implementation is adapted from vLLM's MoE alignment kernels.
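As a rough reference for what "alignment" means here, the sketch below counts how many tokens each expert receives and pads each count up to a multiple of block_size. This is an illustrative host-side C++ sketch only; the name padded_tokens_per_expert is not part of the PR's API, and the real kernels additionally produce the sorted token IDs and expert IDs on device.

```cpp
#include <cstdint>
#include <vector>

// Round x up to the nearest multiple of block_size (mirrors the round_up
// helper added to tests/utils.py).
static int64_t round_up(int64_t x, int64_t block_size) {
  return ((x + block_size - 1) / block_size) * block_size;
}

// Illustrative reference only: count the tokens routed to each expert and
// pad every per-expert count to a whole number of blocks, so the grouped
// GEMM can always operate on full blocks.
std::vector<int64_t> padded_tokens_per_expert(
    const std::vector<int32_t>& topk_ids,  // flattened [num_tokens * top_k] expert IDs
    int64_t num_experts, int64_t block_size) {
  std::vector<int64_t> counts(num_experts, 0);
  for (int32_t e : topk_ids) {
    counts[e] += 1;
  }
  for (auto& c : counts) {
    c = round_up(c, block_size);
  }
  return counts;  // the sum gives the padded length of sorted_token_ids
}
```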
Key changes:
- Implements SYCL/XPU kernels for MOE token alignment with block size constraints
- Adds comprehensive test coverage for both regular and batched alignment scenarios
- Includes utility function for rounding up to block size multiples
Reviewed Changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 2 comments.
Summary per file:
| File | Description |
|---|---|
| tests/utils.py | Adds round_up utility function for block size calculations |
| tests/test_moe_align_block_size.py | Comprehensive test suite covering various scenarios including determinism, expert mapping, and edge cases |
| tests/register_ops.py | Registers the two new MOE alignment operations with PyTorch |
| tests/ops/moe_align_block_size_ops.py | Python wrappers for the MOE alignment operations with detailed documentation |
| csrc/moe/torch_bindings.cpp | Binds the C++ implementations to PyTorch operators |
| csrc/moe/moe_ops.h | Declares the function signatures for the MOE alignment operations |
| csrc/moe/moe_align_sum_kernels.cpp | Implements the SYCL kernels for MOE token alignment operations |
):
    """
    Verify that actual_sorted_ids follows the correct expert-level sorting.
    The kerne limplementation may or may not preserve original token order
Copilot AI (Oct 27, 2025)
Corrected spelling of 'kerne limplementation' to 'kernel implementation'.
-    The kerne limplementation may or may not preserve original token order
+    The kernel implementation may or may not preserve original token order
LGTM!
int32_t* temp_storage = static_cast<int32_t*>(
    slm.template get_multi_ptr<sycl::access::decorated::no>().get());

int32_t* shared_counts = temp_storage + 1024;
Why 1024 here? Please avoid using magic numbers.
temp_storage needs space for 1024 int32 values.
In CUDA this declaration does not need to be written explicitly, but SYCL requires this portion of SLM (shared local memory) to be declared explicitly.
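One way to address the magic-number comment, sketched below with an assumed constant name (kTempStorageInt32Elems is illustrative, not the PR's code), is to name the scratch-region size once and reuse it in the pointer arithmetic:

```cpp
#include <cstdint>

// Hypothetical named constant for the scratch region; the real kernel may
// size or partition its SLM differently.
constexpr int kTempStorageInt32Elems = 1024;

// With the constant, the pointer arithmetic from the diff above would read
// (context abbreviated; `slm` is the kernel's local accessor):
//
//   int32_t* temp_storage = static_cast<int32_t*>(
//       slm.template get_multi_ptr<sycl::access::decorated::no>().get());
//   int32_t* shared_counts = temp_storage + kTempStorageInt32Elems;
```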
Add op moe_align_block_size & batched_moe_align_block_size.
From: https://github.com/vllm-project/vllm/blob/main/csrc/moe/moe_align_sum_kernels.cu