[Perf] Add decode full-graph support to FlashInfer-MLA backend #26313
Conversation
Signed-off-by: Benjamin Chislett <[email protected]>
Code Review
This pull request correctly enables full CUDA graph support for decode operations in the FlashInfer-MLA attention backend. The change is implemented by creating a new FlashInferMLAMetadataBuilder class that inherits from MLACommonMetadataBuilder and sets the cudagraph_support attribute to AttentionCGSupport.UNIFORM_BATCH. The FlashInferMLABackend is then updated to use this new builder. The approach is clean, follows the existing design patterns in the codebase, and seems to correctly enable the feature as described. The changes are minimal and well-targeted. I found no issues of high or critical severity.
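For context, a minimal sketch of what the described change could look like. The import paths and the get_builder_cls hook are assumptions based on vLLM's usual v1 backend layout, not the exact contents of this PR:

```python
# Minimal sketch only; module paths and the get_builder_cls hook are
# assumptions about vLLM's layout, not the exact diff in this PR.
from typing import ClassVar

from vllm.v1.attention.backends.mla.common import MLACommonMetadataBuilder
from vllm.v1.attention.backends.utils import AttentionCGSupport


class FlashInferMLAMetadataBuilder(MLACommonMetadataBuilder):
    # Annotate the builder so uniform (decode-only) batches can be
    # captured into a full CUDA graph.
    cudagraph_support: ClassVar[AttentionCGSupport] = (
        AttentionCGSupport.UNIFORM_BATCH
    )


class FlashInferMLABackend:
    # Only the part relevant to this change is sketched here: the backend
    # now hands out the new builder class.
    @staticmethod
    def get_builder_cls() -> type[FlashInferMLAMetadataBuilder]:
        return FlashInferMLAMetadataBuilder
```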
LGTM; thanks!
Nice
Purpose
The full-cudagraph support annotation was missing from FlashInfer-MLA, even though the implementation already supports it.
Running DSR1-FP4 on 4xB200 gets me 97 TPS:
I also tested on a local development branch for MTP containing #25984 and #25987.
On that branch, with 3 MTP speculative tokens, I get 165 TPS and passing GSM8k evals.
Test Plan
GSM8k was run as follows:
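The exact invocation is not reproduced here; as an illustration only, a GSM8k check of this kind could be driven through lm-evaluation-harness's vLLM backend. The model path, parallelism, and few-shot settings below are placeholders, not the values used for this PR:

```python
# Hypothetical GSM8k check via lm-evaluation-harness; the actual command
# used for this PR is not shown. Model path and parallelism are placeholders.
from lm_eval import simple_evaluate

results = simple_evaluate(
    model="vllm",
    model_args="pretrained=/path/to/DSR1-FP4,tensor_parallel_size=4",
    tasks=["gsm8k"],
    num_fewshot=5,
    batch_size="auto",
)
print(results["results"]["gsm8k"])
```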
Test Result
Matches the baseline: