Conversation

@alex-jw-brooks (Contributor) commented Nov 21, 2025

Purpose

Partial fix for #29166. This doesn't fix the underlying edge case in the kernel that Granite Speech hits, but it does fix the behavior for Granite Speech models, since none of them have LoRA weights for the embedding / LM head.

CC @jeejeelee @DarkLight1337

Test Plan

  • Add a test to make sure the results are equivalent to just calling the base layer (a rough sketch of the check is shown after this list)
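
For reference, an equivalence check of that shape might look roughly like the following. The base_layer attribute and the helper name are assumptions for illustration; the actual test added by this PR lives in vLLM's LoRA test suite.

import torch


def check_zero_lora_matches_base(
    lora_layer: torch.nn.Module, token_ids: torch.Tensor
) -> None:
    # With all-zero LoRA weights, the wrapped embedding layer should produce
    # exactly the same output as calling its base layer directly.
    with torch.no_grad():
        expected = lora_layer.base_layer(token_ids)
        actual = lora_layer(token_ids)
    torch.testing.assert_close(actual, expected)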

Test Result

  • Test passes
  • Granite Speech models no longer hit the illegal memory access in this edge case


@gemini-code-assist bot left a comment

Code Review

This pull request introduces a bug fix to bypass the LoRA forward pass in VocabParallelEmbedding when the LoRA weights are zero. This is a good optimization and addresses an edge case for specific models. The implementation is sound and a corresponding test has been added. However, I've identified a small but significant issue in the new test where variables are swapped, which could lead to confusion and potentially hide bugs in the future. My review includes a suggestion to correct this.

Signed-off-by: Alex-Brooks <[email protected]>
@chatgpt-codex-connector bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.


Comment on lines 93 to +99
def forward(self, x: torch.Tensor) -> torch.Tensor:
    # Short circuit: if either the stacked LoRA A or B weights are all
    # zeros, we can just call the base layer directly.
    if bool(torch.all(self.lora_a_stacked == 0)) or bool(
        torch.all(self.lora_b_stacked == 0)
    ):
        return self.base_layer.forward(x)


P1: Short-circuit scans full LoRA stacks every forward

The new zero check reduces over self.lora_a_stacked and self.lora_b_stacked on every call to forward, converting the GPU result to a Python bool. These tensors are sized max_loras × vocab_size × rank and live on GPU; scanning them each token adds O(max_loras·vocab) work plus a host sync even when LoRA weights are non-zero, which is far heavier than the previous gather-based path and will noticeably slow embedding lookups for any LoRA-enabled run. Consider caching a flag when weights are loaded instead of recomputing a full reduction per forward.
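
For illustration, here is a rough sketch of the caching approach this review suggests: compute the zero check once when LoRA weights are set, instead of reducing over the stacked tensors on every forward. The class, the set_lora hook, and the single-adapter LoRA path below are simplified assumptions for the sketch, not vLLM's actual VocabParallelEmbeddingWithLoRA implementation.

import torch
import torch.nn as nn


class EmbeddingWithLoRASketch(nn.Module):
    """Illustrative only -- not vLLM's VocabParallelEmbeddingWithLoRA."""

    def __init__(self, base_layer: nn.Embedding, max_loras: int, rank: int):
        super().__init__()
        self.base_layer = base_layer
        vocab, dim = base_layer.num_embeddings, base_layer.embedding_dim
        self.lora_a_stacked = torch.zeros(max_loras, vocab, rank)
        self.lora_b_stacked = torch.zeros(max_loras, rank, dim)
        # Cached flag: True while every stacked LoRA weight is zero.
        self._lora_is_zero = True

    def set_lora(self, index: int, lora_a: torch.Tensor, lora_b: torch.Tensor) -> None:
        # Hypothetical weight-update hook: copy the weights, then refresh the
        # flag once, so forward() never has to reduce over the full stacks.
        self.lora_a_stacked[index].copy_(lora_a)
        self.lora_b_stacked[index].copy_(lora_b)
        self._lora_is_zero = bool(
            torch.all(self.lora_a_stacked == 0) or torch.all(self.lora_b_stacked == 0)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self._lora_is_zero:
            # No per-token reduction or host sync; just use the base embedding.
            return self.base_layer(x)
        # Simplified single-adapter LoRA path, for illustration only.
        out = self.base_layer(x)
        lora_a = self.lora_a_stacked[0][x]            # (..., rank)
        return out + lora_a @ self.lora_b_stacked[0]  # (..., embedding_dim)

With a change of this shape, the cost of the zero check moves from every embedding lookup to the much rarer weight-loading path, avoiding the per-token host sync described above.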


@jeejeelee (Collaborator) commented:

Give me some time to look into the root cause of this bug, thank you.

@alex-jw-brooks (Contributor, Author) commented:

Sure, thank you @jeejeelee!
