-
Notifications
You must be signed in to change notification settings - Fork 600
[Kernel] add triton kernels for sampling #4394
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
|
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run linting and testing checks locally according Contributing and Testing. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request replaces several PyTorch-based sampling functions with Triton kernels, aiming to improve performance. The new kernels for expand_kernel and sample_recovered_tokens_kernel are well-implemented. However, rejection_greedy_sample_kernel and rejection_random_sample_kernel contain sequential loops over draft tokens. This is a performance anti-pattern in Triton that prevents vectorization and may not yield the desired speed-up. I've added specific comments with suggestions to vectorize these kernels for better efficiency.
|
This pull request has conflicts, please resolve those before we can evaluate the pull request. |
Signed-off-by: MidnightSun <[email protected]>
| max_spec_len, | ||
| is_greedy, | ||
| ) | ||
| rejection_greedy_sample_kernel[(batch_size,)]( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
After communicating with @wangxiyuan, you need to use HAS_TRITON check here. If user hasn't installed triton, fall back to original implementation to ensure functionality. An example is here:
vllm-ascend/vllm_ascend/attention/sfa_v1.py
Line 505 in a3225c4
| if HAS_TRITON: |
| else: | ||
| from vllm.v1.sample.rejection_sampler import apply_sampling_constraints | ||
|
|
||
| import triton.language as tl |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use from vllm.triton_utils import tl, triton here, same reason.
|
In conclusion, currently we need to make sure that original functionality won't be broken by new triton optimization in environments without triton. |
What this PR does / why we need it?
Replace pyorch implement of sampling with triton kernels
Does this PR introduce any user-facing change?
No
How was this patch tested?