Fix max_seq_len_k in trtllm_mha attention backend #9416
Conversation
Summary of Changes
Hello @Qiaolin-Yu, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request addresses an issue in the `trtllm_mha` attention backend where the `max_seq_len_k` parameter was being handled inconsistently. The fix ensures that `max_seq_len_k` within the forward metadata accurately reflects the maximum sequence length of the current batch, while the underlying attention kernel consistently receives the global `max_context_len` for proper operation. This correction aims to improve the stability and correctness of the attention mechanism, particularly when using CUDA graphs.
Highlights
- Corrected `max_seq_len_k` calculation: the `max_seq_len_k` attribute within the forward metadata is now dynamically calculated from the maximum sequence length present in the current batch, rather than using the global maximum context length.
- Consistent use of `max_context_len` in attention kernel calls: the `max_seq_len` and `max_kv_len` parameters passed to the core attention kernel (`_attn_fwd`) in both the `forward_decode` and `forward_extend` methods now consistently use `self.max_context_len`, ensuring the kernel operates with the correct overall context capacity (see the sketch after this list).
- Enhanced stability and correctness: these changes resolve an inconsistency in how sequence length parameters were handled, improving the stability and correctness of the `trtllm_mha` attention backend, especially when leveraging CUDA graphs.
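A minimal sketch of the split described above, with illustrative names only (`seq_lens`, `run_attention_kernel`, and the hard-coded lengths are assumptions for this sketch, not the backend's actual code):

```python
import torch


def run_attention_kernel(max_seq_len: int, max_kv_len: int) -> None:
    # Stand-in for the real trtllm_mha kernel call; only the length arguments matter here.
    print(f"kernel sized for max_seq_len={max_seq_len}, max_kv_len={max_kv_len}")


def forward(seq_lens: torch.Tensor, max_context_len: int) -> None:
    # The forward metadata tracks the longest sequence in *this* batch...
    max_seq_len_k = int(seq_lens.max().item())
    print(f"metadata.max_seq_len_k = {max_seq_len_k}")

    # ...while the kernel receives the static, model-wide bound, since its
    # KV-cache / page-table layout is sized for the full context capacity.
    run_attention_kernel(max_seq_len=max_context_len, max_kv_len=max_context_len)


forward(torch.tensor([17, 93, 5]), max_context_len=8192)
```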
Code Review
This pull request correctly fixes an issue with the `max_seq_len_k` parameter in the `trtllm_mha` attention backend. The changes properly distinguish between the maximum sequence length within a batch and the model's overall maximum context length. The modifications ensure that `metadata.max_seq_len_k` accurately reflects the current batch's maximum sequence length, which is semantically correct. Most importantly, the calls to the underlying flashinfer kernels in `forward_decode` and `forward_extend` are updated to use `self.max_context_len`. This is the correct approach, as these kernels likely require the maximum potential length for which the KV cache structures are sized, rather than the dynamic maximum length of the current batch. The fix is logical, well-targeted, and resolves the inconsistency in the original implementation.
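To make the "sized for the maximum potential length" point concrete, here is a hedged sketch of a paged layout (the page size, shapes, and names are assumptions for illustration, not the backend's actual structures):

```python
import math

import torch

PAGE_SIZE = 16          # assumed page granularity
MAX_CONTEXT_LEN = 8192  # static, model-wide bound
BATCH_SIZE = 4

# The page table is allocated once, wide enough for the longest possible sequence;
# its width does not shrink when the current batch happens to be short.
max_pages_per_seq = math.ceil(MAX_CONTEXT_LEN / PAGE_SIZE)
page_table = torch.zeros(BATCH_SIZE, max_pages_per_seq, dtype=torch.int32)

# A kernel that derives its page-table indexing from a max-length argument therefore
# needs MAX_CONTEXT_LEN, not the dynamic maximum of the batch currently in flight.
print(page_table.shape, page_table.stride(0))  # torch.Size([4, 512]) 512
```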
LGTM. The `max_seq_len` should be >= the page table stride for the trtllm-gen family of attention kernels.
We could continue on flashinfer-ai/flashinfer#1407 to remove this param after the new cubin publishing and launcher-params refactor (flashinfer-ai/flashinfer#1518). cc @yzh119 for confirming this todo item.
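A hedged illustration of that invariant (the helper name and the stride-based check are assumptions for this sketch, not the kernels' documented contract):

```python
import torch


def check_kernel_max_seq_len(max_seq_len: int, page_table: torch.Tensor) -> None:
    # Mirrors the note above: the max_seq_len handed to the trtllm-gen attention
    # kernels should be at least the page table's row stride.
    assert max_seq_len >= page_table.stride(0), (
        f"max_seq_len={max_seq_len} < page table stride {page_table.stride(0)}"
    )


page_table = torch.zeros(4, 512, dtype=torch.int32)
check_kernel_max_seq_len(8192, page_table)  # OK: 8192 >= 512
```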
Motivation
Modifications
Accuracy Tests
Accuracy: 0.7089646464646465
Benchmarking and Profiling
Checklist