test vllm-ascend triton ops #4418
base: main
Conversation
ok
Code Review
This pull request introduces several Triton kernels to optimize the rejection sampling process in vllm-ascend, replacing existing PyTorch implementations. The changes aim to improve performance by leveraging GPU-specific optimizations. The PR also includes refactoring of helper functions for better vectorization and clarity. My review identified a critical bug in the new rejection_greedy_sample_kernel where a condition to filter non-greedy requests is incorrect, leading to wrong behavior. I also found a high-severity issue in the refactored sample_recovered_tokens_pytorch function, where an incorrect tensor slicing could cause runtime errors. I've provided suggestions to fix both issues.
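For background, these kernels implement the standard speculative-decoding rejection rule: a draft token x is accepted with probability min(1, p(x) / q(x)), where p is the target-model distribution and q the draft-model distribution. A generic sketch of that acceptance check (not code from this PR; the function and argument names are illustrative):

```python
import torch


def accept_draft_token(p: torch.Tensor, q: torch.Tensor, draft_id: int) -> bool:
    """Accept draft_id with probability min(1, p[draft_id] / q[draft_id])."""
    ratio = (p[draft_id] / q[draft_id]).clamp(max=1.0)
    return bool(torch.rand(()) < ratio)
```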
```python
if is_greedy is None:
    # Early exit for non-greedy sampling requests.
    return
```
The condition is_greedy is None is incorrect for a Triton kernel. The is_greedy variable, being a Triton boolean scalar (tl.int1), can never be Python's None. This bug prevents the intended early exit for non-greedy sampling requests, causing them to be processed incorrectly by this kernel designed for greedy sampling. The check should be if not is_greedy: to correctly handle non-greedy cases.
Suggested change:
```diff
-if is_greedy is None:
-    # Early exit for non-greedy sampling requests.
-    return
+if not is_greedy:
+    # Early exit for non-greedy sampling requests.
+    return
```
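For illustration, a minimal self-contained sketch of this early-exit pattern on a loaded per-request flag. This is not the PR's actual kernel; the kernel name, argument layout, and launch parameters are assumptions:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def greedy_accept_sketch_kernel(
    output_token_ids_ptr,  # [num_reqs] int64 output buffer
    target_argmax_ptr,     # [num_reqs] argmax of the target-model logits
    is_greedy_ptr,         # [num_reqs] per-request greedy flags (bool)
):
    req_idx = tl.program_id(0)
    is_greedy = tl.load(is_greedy_ptr + req_idx)
    if not is_greedy:
        # A loaded tl.int1 value is never Python's None; test the value itself.
        return
    token_id = tl.load(target_argmax_ptr + req_idx)
    tl.store(output_token_ids_ptr + req_idx, token_id)


# Example launch: one program per request (device string assumed here;
# vllm-ascend would target an NPU device instead).
num_reqs = 4
out = torch.empty(num_reqs, dtype=torch.int64, device="cuda")
argmax = torch.randint(0, 32000, (num_reqs,), dtype=torch.int64, device="cuda")
flags = torch.tensor([True, False, True, True], device="cuda")
greedy_accept_sketch_kernel[(num_reqs,)](out, argmax, flags)
```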
```python
q_values = torch.full((num_tokens, vocab_size),
                      float('-inf'),
                      device=q.device)
q_values[:vocab_size] = q[req_idx, :vocab_size]
...
recovered_id = torch.argmax(prob / q_values).item()
output_token_ids[token_idx] = recovered_id
q_values[:vocab_size] = q_value_new[token_positions, :vocab_size]
```
The current initialization and population of q_values is incorrect and overly complex. The row-wise slice q_values[:vocab_size] can cause a shape mismatch error if num_tokens > vocab_size. This block can be simplified by directly assigning q_value_new to q_values, which also removes the redundant torch.full initialization.
Suggested change:
```python
q_values = q_value_new
```
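For context, a hedged sketch of how the simplified per-token recovery loop could look with this change. The function name, signature, and loop structure are assumptions for illustration; only the variable names mirror the snippet above:

```python
import torch


def sample_recovered_tokens_sketch(
    prob: torch.Tensor,              # [num_tokens, vocab_size] target probabilities
    q_value_new: torch.Tensor,       # [num_tokens, vocab_size] draft (proposal) probabilities
    output_token_ids: torch.Tensor,  # [num_tokens] int64 output buffer
) -> None:
    num_tokens = prob.shape[0]
    for token_idx in range(num_tokens):
        # Use the proposal row for this token directly; no torch.full(-inf)
        # buffer or row-wise slicing is needed.
        q_values = q_value_new[token_idx]
        recovered_id = torch.argmax(prob[token_idx] / q_values).item()
        output_token_ids[token_idx] = recovered_id
```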
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run linting and testing checks locally according to Contributing and Testing.
What this PR does / why we need it?
Does this PR introduce any user-facing change?
How was this patch tested?