[BugFix] Call Base Layer Directly if LoRA A/B in Parallel Vocab are 0 #29167
Conversation
Signed-off-by: Alex-Brooks <[email protected]>
Code Review
This pull request introduces a bug fix to bypass the LoRA forward pass in VocabParallelEmbedding when the LoRA weights are zero. This is a good optimization and addresses an edge case for specific models. The implementation is sound and a corresponding test has been added. However, I've identified a small but significant issue in the new test where variables are swapped, which could lead to confusion and potentially hide bugs in the future. My review includes a suggestion to correct this.
Signed-off-by: Alex-Brooks <[email protected]>
💡 Codex Review
Here are some automated review suggestions for this pull request.
def forward(self, x: torch.Tensor) -> torch.Tensor:
    # Short circuit: if either LoRA A or B is all zeros, the LoRA path
    # contributes nothing, so call the base layer directly.
    if bool(torch.all(self.lora_a_stacked == 0)) or bool(
        torch.all(self.lora_b_stacked == 0)
    ):
        return self.base_layer.forward(x)
Short-circuit scans full LoRA stacks every forward
The new zero check reduces over self.lora_a_stacked and self.lora_b_stacked on every call to forward, converting the GPU result to a Python bool. These tensors are sized max_loras × vocab_size × rank and live on GPU; scanning them each token adds O(max_loras·vocab) work plus a host sync even when LoRA weights are non-zero, which is far heavier than the previous gather-based path and will noticeably slow embedding lookups for any LoRA-enabled run. Consider caching a flag when weights are loaded instead of recomputing a full reduction per forward.
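A minimal sketch of the caching approach suggested above, assuming a simplified wrapper class; the names here (LoRAEmbeddingSketch, set_lora, _compute_zero_flag) are illustrative and not vLLM's actual API:

import torch


class LoRAEmbeddingSketch:
    """Illustrative wrapper: cache the "LoRA is all-zero" flag at load time."""

    def __init__(self, base_layer, lora_a_stacked: torch.Tensor,
                 lora_b_stacked: torch.Tensor):
        self.base_layer = base_layer
        self.lora_a_stacked = lora_a_stacked
        self.lora_b_stacked = lora_b_stacked
        # One full reduction + host sync here, not on every forward call.
        self._lora_is_zero = self._compute_zero_flag()

    def _compute_zero_flag(self) -> bool:
        return bool(torch.all(self.lora_a_stacked == 0)) or bool(
            torch.all(self.lora_b_stacked == 0)
        )

    def set_lora(self, index: int, lora_a: torch.Tensor,
                 lora_b: torch.Tensor) -> None:
        # Hypothetical weight-loading hook: refresh the cached flag whenever
        # adapter weights change, so forward() never has to recompute it.
        self.lora_a_stacked[index].copy_(lora_a)
        self.lora_b_stacked[index].copy_(lora_b)
        self._lora_is_zero = self._compute_zero_flag()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self._lora_is_zero:
            # Nothing to add: the base embedding output is already correct.
            return self.base_layer(x)
        raise NotImplementedError("regular LoRA-augmented embedding path")


# Example: zero-initialized stacks short-circuit straight to the base layer.
emb = torch.nn.Embedding(32000, 64)
layer = LoRAEmbeddingSketch(emb, torch.zeros(4, 32000, 16), torch.zeros(4, 64, 16))
out = layer.forward(torch.tensor([1, 2, 3]))  # hits the base embedding directly

Under this scheme the per-token cost of the short-circuit is a single Python attribute check, and the full reduction only runs when adapter weights are (re)loaded.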
Give me some time to look into the root cause of this bug, thank you.
Sure, thank you @jeejeelee!
Purpose
Partial fix for #29166. This does not fix the underlying kernel edge case hit by Granite Speech, but it does fix the behavior for Granite Speech models, since none of them have LoRA weights for the embedding/LM head.
CC @jeejeelee @DarkLight1337
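For context, a small illustration of why the stacks end up all-zero for these models; the shapes and names below are made up for the example and are not vLLM's actual buffers:

import torch

# Illustrative shapes only: (max_loras, vocab_size, rank) and (max_loras, hidden, rank).
max_loras, vocab_size, rank, hidden = 4, 32000, 16, 64
lora_a_stacked = torch.zeros(max_loras, vocab_size, rank)
lora_b_stacked = torch.zeros(max_loras, hidden, rank)

# A Granite Speech adapter ships no embedding/lm_head LoRA weights, so nothing
# is ever copied into these buffers. Both stacks stay identically zero, the
# LoRA contribution is zero, and the base layer output is already correct.
assert bool(torch.all(lora_a_stacked == 0)) and bool(torch.all(lora_b_stacked == 0))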
Test Plan
Test Result
Essential Elements of an Effective PR Description Checklist
supported_models.md and examples for a new model.