[kernel] Recompilation optimization triggered by triton function parameter optimization #7647
HarpsealCC wants to merge 0 commits into vllm-project:releases/v0.18.0
Conversation
Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request enhances the performance of Triton kernels through strategic optimizations to parameter handling. The core changes involve modifying how kernel parameters are declared.

Highlights
Code Review
This pull request refactors Triton kernel definitions across multiple files, primarily by expanding the do_not_specialize list for @triton.jit decorators and converting several tl.constexpr parameters to regular runtime arguments. Additionally, it simplifies the grid and block size calculation logic in reject_sample.py and fused_gdn_gating.py. However, the changes to reject_sample.py introduce critical bugs by removing a necessary import (get_vectorcore_num) and implementing an incorrect grid/block size calculation, which will lead to indexing errors in dependent kernels. Furthermore, the simplified batching logic in fused_gdn_gating.py may cause a performance regression for small batch sizes, and an optimized approach for these cases is recommended.
```diff
  from vllm.triton_utils import tl, triton

- from vllm_ascend.ops.triton.triton_utils import get_element, get_vectorcore_num
+ from vllm_ascend.ops.triton.triton_utils import get_element
```
The function get_vectorcore_num is needed for the correct implementation of cal_grid_and_block_size but was removed. Please add it back to the imports.
```diff
- from vllm_ascend.ops.triton.triton_utils import get_element
+ from vllm_ascend.ops.triton.triton_utils import get_element, get_vectorcore_num
```
```diff
  def cal_grid_and_block_size(batch_size: int):
-     vectorcore_num = get_vectorcore_num()
-     if batch_size <= vectorcore_num:
-         grid = batch_size
-         block_size = 1
-     else:
-         grid = vectorcore_num
-         block_size = triton.next_power_of_2(triton.cdiv(batch_size, grid))
+     grid = batch_size
+     block_size = 64
      return grid, block_size
```
The new implementation of cal_grid_and_block_size is incorrect. It sets grid = batch_size and block_size = 64, which causes incorrect indexing within the Triton kernels that use it (e.g., rejection_greedy_sample_triton, expand_kernel). This will lead to bugs and incorrect results because the kernels are written to process block_size items per program instance, but with grid = batch_size, the indexing logic will be wrong. The previous implementation correctly calculated grid and block sizes. Please revert to the previous logic.
```python
def cal_grid_and_block_size(batch_size: int):
    vectorcore_num = get_vectorcore_num()
    if batch_size <= vectorcore_num:
        grid = batch_size
        block_size = 1
    else:
        grid = vectorcore_num
        block_size = triton.next_power_of_2(triton.cdiv(batch_size, grid))
    return grid, block_size
```

In fused_gdn_gating.py, the simplified logic under discussion:

```diff
+ progs = num_cores
+ row_per_core = triton.cdiv(batch, progs)
+ BLK_BATCHES = 64
+ ROW_ITER = triton.cdiv(row_per_core, BLK_BATCHES)
```
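The reviewer's point about the previous grid/block logic can be checked with a plain-Python model. Here `next_power_of_2` and `cdiv` stand in for `triton.next_power_of_2` / `triton.cdiv`, and the `vectorcore_num` default of 32 is an illustrative assumption in place of `get_vectorcore_num()`:

```python
import math


def next_power_of_2(n: int) -> int:
    # Plain-Python stand-in for triton.next_power_of_2.
    return 1 if n <= 0 else 1 << (n - 1).bit_length()


def cal_grid_and_block_size(batch_size: int, vectorcore_num: int = 32):
    # Previous logic: for small batches, one item per program; for large
    # batches, cap the grid at the core count and size the block so that
    # grid * block_size covers the whole batch.
    if batch_size <= vectorcore_num:
        grid = batch_size
        block_size = 1
    else:
        grid = vectorcore_num
        block_size = next_power_of_2(math.ceil(batch_size / grid))
    return grid, block_size
```

This preserves the invariant `grid * block_size >= batch_size` with each program handling `block_size` consecutive items, whereas the hard-coded `grid = batch_size, block_size = 64` pairing breaks the per-program indexing the kernels were written against.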
The logic for determining progs, BLK_BATCHES, and ROW_ITER has been simplified, but this may cause a performance regression for small batch sizes. The previous logic handled small batches (batch <= num_cores) more efficiently by setting progs = batch and BLK_BATCHES = 1. With the new logic, if batch is small, many program instances in the grid will be idle or perform masked work, leading to inefficiency.
For example, if batch=1 and num_cores=32, the new logic launches a grid of 32 programs, but only one will do useful work. The old logic would have launched a grid of 1.
This suggestion re-introduces the efficient handling for small batches while keeping the simplified logic for larger batches.
```diff
- progs = num_cores
- row_per_core = triton.cdiv(batch, progs)
- BLK_BATCHES = 64
- ROW_ITER = triton.cdiv(row_per_core, BLK_BATCHES)
+ if batch <= num_cores:
+     progs = batch
+     BLK_BATCHES = 1
+     ROW_ITER = 1
+ else:
+     progs = num_cores
+     row_per_core = triton.cdiv(batch, progs)
+     BLK_BATCHES = 64
+     ROW_ITER = triton.cdiv(row_per_core, BLK_BATCHES)
```
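The suggested dispatch can be modeled in plain Python to confirm the small-batch behavior. The `num_cores` default of 32 is an illustrative assumption standing in for the device's vector-core count, and `math.ceil` replaces `triton.cdiv`:

```python
import math


def launch_config(batch: int, num_cores: int = 32):
    # Small batches: one row per program, so no program in the grid is
    # idle. Large batches: cap the grid at num_cores and iterate over
    # blocks of 64 rows per program.
    if batch <= num_cores:
        progs = batch
        blk_batches = 1
        row_iter = 1
    else:
        progs = num_cores
        row_per_core = math.ceil(batch / progs)
        blk_batches = 64
        row_iter = math.ceil(row_per_core / blk_batches)
    return progs, blk_batches, row_iter
```

With `batch=1` this launches a grid of 1 program, matching the reviewer's example, while the simplified logic would launch `num_cores` programs of which only one does useful work.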
[kernel] Recompilation optimization triggered by triton function parameter optimization
What this PR does / why we need it?
Does this PR introduce any user-facing change?
How was this patch tested?