[Refactor] optimize sample_recover method in reject_sampler #3727
lio1226 wants to merge 4 commits into vllm-project:main from
Conversation
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run linting and testing checks locally according to Contributing and Testing.
Code Review
The pull request optimizes the sample_recovered_tokens_pytorch method in rejection_sampler.py to improve the performance of EAGLE-3. The optimization replaces the nested per-token loops with vectorized torch operations, which should reduce execution time. I have identified a potential issue related to the indexing of q_values.
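For context, here is a minimal sketch of the vectorized pattern the review describes. This is not the PR's actual code; the function name and signature are assumptions. It shows the batched form of the exponential-race (Gumbel-max) sampling that the `argmax(prob / q_values)` expression quoted below implements:

```python
import torch

def sample_recovered_tokens_vectorized(probs: torch.Tensor) -> torch.Tensor:
    """probs: [num_tokens, vocab_size] adjusted distributions."""
    # One Exponential(1) draw per (token, vocab) entry, replacing the
    # per-token Python loop with a single batched operation.
    q = torch.empty_like(probs).exponential_(1.0)
    # argmax(p / q) over the vocab dim samples index i with probability p_i
    # (the exponential-race formulation of the Gumbel-max trick).
    return torch.argmax(probs / q, dim=-1)
```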
```python
recovered_id = torch.argmax(prob / q_values).item()
output_token_ids[token_idx] = recovered_id
q_values[:vocab_size] = q_value_new[token_positions, :vocab_size]
```
The indexing q_values[:vocab_size] might lead to incorrect behavior. Both q_values and q_value_new have the shape (num_tokens, vocab_size), so the slice q_values[:vocab_size] selects the first vocab_size rows (token positions), not columns. Assigning q_value_new[token_positions, :vocab_size] to it therefore updates only the first vocab_size rows, while the remaining rows stay at -inf. This is likely not the intended behavior, as it skews the sampling for tokens beyond the first vocab_size positions.
To fix this, assign the entire q_value_new to q_values without slicing. This ensures that all token positions have the correct q-values for the subsequent argmax operation.
Severity: critical
```diff
- q_values[:vocab_size] = q_value_new[token_positions, :vocab_size]
+ q_values = q_value_new
```
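A tiny illustration of the point above, with hypothetical sizes (the tensor names mirror the snippet; the values are made up): slicing with [:vocab_size] selects rows, so only the first vocab_size token positions are updated.

```python
import torch

num_tokens, vocab_size = 8, 4  # hypothetical, num_tokens > vocab_size
q_values = torch.full((num_tokens, vocab_size), float("-inf"))
q_value_new = torch.rand(num_tokens, vocab_size)

# Sliced assignment updates rows 0..vocab_size-1 only.
q_values[:vocab_size] = q_value_new[:vocab_size]
assert torch.isinf(q_values[vocab_size:]).all()  # rows 4..7 are still -inf

# Full assignment, as suggested, covers every token position.
q_values = q_value_new
assert torch.isfinite(q_values).all()
```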
Signed-off-by: lio <1983142975@qq.com>
c3edfda to f128fd5
This pull request has conflicts; please resolve them before we can evaluate the pull request.
Any progress? If this PR is still alive, please rebase onto main and make CI happy. Thanks
The main branch now includes the work done in this pull request, so this PR can be closed.
What this PR does / why we need it?
We optimized the sample_recovered_tokens_pytorch method in the rejection sampler to improve the performance of EAGLE-3.
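For contrast, a hedged sketch of the per-token loop pattern this PR replaces, reconstructed from the snippet quoted in the review above (not the actual vllm-ascend source; the function name is hypothetical):

```python
import torch

def sample_recovered_tokens_loop(probs: torch.Tensor) -> torch.Tensor:
    """probs: [num_tokens, vocab_size]; sequential baseline for contrast."""
    num_tokens = probs.shape[0]
    output_token_ids = torch.empty(num_tokens, dtype=torch.long)
    for token_idx in range(num_tokens):
        prob = probs[token_idx]
        # Fresh Exponential(1) noise per token, then the same argmax trick.
        q_values = torch.empty_like(prob).exponential_(1.0)
        recovered_id = torch.argmax(prob / q_values).item()
        output_token_ids[token_idx] = recovered_id
    return output_token_ids
```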
Does this PR introduce any user-facing change?
How was this patch tested?
None
Co-authored-by: QilaiZhang <245706640@qq.com>
vLLM version: v0.11.0rc3
vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0