[Feat] skip lightning indexer for the first 2048 preceding tokens #7418
1024daniel wants to merge 1 commit into vllm-project:main
Conversation
Summary of Changes (Gemini Code Assist)

This pull request introduces a significant performance optimization by allowing the lightning indexer to skip the initial 2048 tokens during computation. This change aims to reduce unnecessary processing, thereby enhancing the efficiency of the attention mechanism. The implementation extends existing metadata structures, adds new utility functions for index manipulation, and integrates this logic into the model's execution flow.
Code Review
This pull request introduces a "Lightning Indexer Skip" feature to optimize attention computation by selectively skipping tokens based on a defined threshold. This involves adding new metadata fields to AscendSFAMetadata and AscendCommonAttentionMetadata, implementing utility functions like get_sfa_skip_indices and get_index_of_skipped_queries_numpy for index calculation and reordering, and modifying the attention and model execution logic to handle skipped sequences. Specifically, the indexer_select_post_process method was refactored to conditionally process tokens and concatenate results from skipped and non-skipped sequences, and the forward method was updated to prevent kv_cache updates for fully skipped sequences. A new enable_lightning_indexer_skip function was added to control this feature. Review comments suggest fixing a typo in skip_threold to skip_threshold and defining 2048 as a shared constant to improve readability and maintainability, as well as removing a redundant import numpy as np statement from within a function.
```python
def get_sfa_skip_indices(num_comptuted_tokens, query_lens):
    num_comptuted_tokens = to_numpy(num_comptuted_tokens)
    query_lens = to_numpy(query_lens)
    skip_threold = 2048
```
The variable skip_threold has a typo and should be skip_threshold. Additionally, the value 2048 is a magic number. It should be defined as a constant with a descriptive name at the module level, for example, LIGHTNING_INDEXER_SKIP_THRESHOLD = 2048. This will improve code readability and make it easier to change this value in the future if needed. The same magic number 2048 is also used in get_index_of_skipped_queries_numpy in vllm_ascend/attention/sfa_v1.py. Using a shared constant would be ideal.
```diff
-    skip_threold = 2048
+    skip_threshold = 2048
```
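A minimal sketch of the suggested fix, assuming the feature's stated semantics (requests whose tokens all fall within the first 2048 positions can skip the indexer). The constant name and the exact skip condition are illustrative, not taken from the PR:

```python
import numpy as np

# Hypothetical name for the shared module-level constant the review asks for.
LIGHTNING_INDEXER_SKIP_THRESHOLD = 2048

def get_sfa_skip_indices(num_computed_tokens, query_lens):
    # Return indices of requests whose tokens all lie within the first
    # LIGHTNING_INDEXER_SKIP_THRESHOLD positions of their sequence, so
    # the lightning indexer can be skipped for them (assumed semantics).
    num_computed_tokens = np.asarray(num_computed_tokens)
    query_lens = np.asarray(query_lens)
    end_positions = num_computed_tokens + query_lens
    return np.nonzero(end_positions <= LIGHTNING_INDEXER_SKIP_THRESHOLD)[0]
```

With the constant defined once at module level, both this function and `get_index_of_skipped_queries_numpy` in `vllm_ascend/attention/sfa_v1.py` can reference the same value.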
```python
    actual_seq_lengths_query = to_numpy(actual_seq_lengths_query)
    actual_seq_lengths_key = to_numpy(actual_seq_lengths_key)
    num_actual_seqs = to_numpy(num_actual_seqs)
    import numpy as np
```
The import numpy as np statement is inside the get_index_of_skipped_queries_numpy function. According to Python best practices, imports should be at the top of the module to improve readability and avoid repeated imports. Since numpy is already imported at the top of the file in this pull request, this line is redundant and should be removed.
Has moved the import to the top of the module.
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run the linting and testing checks locally according to Contributing and Testing.
vllm_ascend/attention/utils.py
Outdated
```python
    return x


def get_sfa_skip_indices(num_comptuted_tokens, query_lens):
```
I think it would be better to rename this function, e.g. `get_li_skip_indices`.
Has changed the function name.
vllm_ascend/attention/sfa_v1.py
Outdated
```python
        # =========================
        if attn_metadata.skip:
            num_tokens = attn_metadata.non_skip_num_actual_tokens
            if num_tokens > 0:
```
I suggest refactoring this logic once the capability work is finished, in order to separate the skip and non-skip token sequences for the function invocation.
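One possible shape for that refactor, sketched with NumPy. The helper name and the token-level reordering are assumptions for illustration, not the PR's actual `get_index_of_skipped_queries_numpy`:

```python
import numpy as np

def reorder_for_skip(query_lens, skip_mask):
    # Build per-token gather indices that place tokens of non-skipped
    # sequences first, followed by tokens of skipped sequences, so each
    # group can be handed to its own code path as one contiguous block.
    query_lens = np.asarray(query_lens, dtype=np.int64)
    starts = np.concatenate(([0], np.cumsum(query_lens)[:-1]))
    non_skip = [np.arange(s, s + n) for s, n, m
                in zip(starts, query_lens, skip_mask) if not m]
    skipped = [np.arange(s, s + n) for s, n, m
               in zip(starts, query_lens, skip_mask) if m]
    groups = non_skip + skipped
    return np.concatenate(groups) if groups else np.empty(0, dtype=np.int64)
```

The inverse permutation of the returned indices would scatter the indexer's outputs back into the original request order.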
vllm_ascend/attention/sfa_v1.py
Outdated
```diff
     k_li = self._get_full_kv(k_li, attn_metadata)

-    if kv_cache is not None:
+    if kv_cache is not None and (not attn_metadata.skip or attn_metadata.non_skip_num_actual_tokens > 0):
```
This condition may lead to precision issues. The skip indices are introduced to reduce matmul and LI operator computation, but k_li still needs to be stored globally for use in subsequent scheduling batches.
Has fixed this branch condition.
This pull request has conflicts, please resolve those before we can evaluate the pull request.
Force-pushed from 4c0f99e to 3fbbf99
Force-pushed from b823a22 to 3e51b42
Co-authored-by: YzTongNiar <1667927948@qq.com> Co-authored-by: wyh145 <1987244901@qq.com> Signed-off-by: 1024daniel <xxltju324@gmail.com>
```python
    # Metadata for Prefill Context Parallelism (PCP) operations.
    prefill_context_parallel_metadata: AscendPrefillContextParallelMetadata | None = None

    lightning_indexer_metadata: AscendLightningIndexerMetadata | None = None
```
We can refactor 'lightning_indexer_metadata' as 'lightning_indexer_context' just like 'dsa_cp_context' in the sfa metadata builder.
```python
            .pin_memory()
            .to(dtype=torch.bool, device=self.device, non_blocking=True)
        )
        common_attn_metadata.lightning_indexer_metadata = AscendLightningIndexerMetadata(
```
Refactor 'AscendLightningIndexerMetadata' as 'lightning_indexer_context' (just like 'dsa_cp_context') to avoid building an extra metadata object.
```python
        seq_lens = common_attn_metadata.seq_lens[:num_reqs]

        query_start_loc = common_attn_metadata.query_start_loc[: num_reqs + 1]
        tokens = query_start_loc[1:] - query_start_loc[:-1]
```
should rename 'tokens' as 'num_computed_tokens' to clarify its usage.
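The quantity being renamed is the per-request token count recovered from the cumulative start offsets. A small self-contained illustration; the concrete values are made up:

```python
import numpy as np

# query_start_loc[i] is the cumulative token offset at which request i
# starts; adjacent differences give each request's token count.
query_start_loc = np.array([0, 4, 4, 9])  # hypothetical offsets for 3 requests
tokens = query_start_loc[1:] - query_start_loc[:-1]
print(tokens.tolist())  # -> [4, 0, 5]
```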
```python
        dsa_cp_context = None
        if self.enable_dsa_cp:
            num_of_non_skip_tokens = 0
            num_segs_for_cp = cum_query_lens.shape[0]
```
should rename 'num_segs_for_cp' as 'num_segs' since it does not actually depend on cp.
What this PR does / why we need it?
Skip the first 2048 tokens for the lightning indexer to avoid redundant computation.
Does this PR introduce any user-facing change?
No
How was this patch tested?