
Conversation

@kylesayrs
Collaborator

@kylesayrs kylesayrs commented Dec 1, 2025

Purpose

  • To support calibration with large batch sizes, the lm_head must be skipped; otherwise too much memory is used
  • This implementation ensures that lm_head inputs can still be calibrated. While this technically invalidates calibrating the lm_head's output activations, that case is far enough outside the scope of supported features that it can safely be ignored or allowed to error.
[Memory usage screenshots: "With LM Head" (without_disable) vs. "LM Head Disabled" (with_disable)]

Llama 8B, 32 samples, sequence length 2048, batch size 16

Prerequisites

Changes

  • disable_lm_head replaces the module's forward with a dummy forward function that returns meta outputs (see the sketch after this list)
    • Because calibration hooks trigger before forward is called, lm_head inputs can still be calibrated
    • This is safe with the transformers implementations, since none of them access the lm_head weights directly
    • This context is always entered as a subcontext of calibration_forward_context
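
A minimal sketch of this context manager, based on the snippets in this PR; patch_attr here is a simplified stand-in for the project's utility, and get_output_embeddings() assumes a Hugging Face transformers model:

```python
import contextlib

import torch


@contextlib.contextmanager
def patch_attr(obj, name, value):
    # Simplified stand-in for the patch_attr utility used in this PR:
    # swap an attribute for the duration of the context, then restore it
    original = getattr(obj, name)
    setattr(obj, name, value)
    try:
        yield
    finally:
        setattr(obj, name, original)


@contextlib.contextmanager
def disable_lm_head(model: torch.nn.Module):
    # Replace the lm_head forward with a dummy that returns meta outputs,
    # so large-batch calibration never materializes the full logits tensor
    lm_head = model.get_output_embeddings()  # assumes a HF transformers model
    dummy_weight = torch.empty_like(lm_head.weight, device="meta")

    def dummy_forward(self, input: torch.Tensor) -> torch.Tensor:
        # Forward pre-hooks (used to calibrate lm_head inputs) still fire
        # before this runs; the matmul on meta tensors allocates no memory
        return input.to("meta") @ dummy_weight.T

    with patch_attr(lm_head, "forward", dummy_forward.__get__(lm_head)):
        yield
```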

Testing

  • Added test_disable_lm_head
  • Added lm_head disabling tests for test_calibration_forward_context (a rough test sketch follows)
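
A rough sketch of what such a test might assert, using the disable_lm_head sketch above; this is illustrative only and not the actual test added in the PR:

```python
import torch


def check_disable_lm_head(model):
    # model: any Hugging Face causal LM loaded on a real device
    captured = {}
    lm_head = model.get_output_embeddings()
    handle = lm_head.register_forward_pre_hook(
        lambda module, args: captured.update(hidden=args[0])
    )

    input_ids = torch.zeros(1, 8, dtype=torch.long)
    with disable_lm_head(model):
        output = model(input_ids)

    # lm_head inputs were still observed (and could be calibrated) ...
    assert captured["hidden"].device.type != "meta"
    # ... while the logits themselves were never materialized
    assert output.logits.device.type == "meta"
    handle.remove()
```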

@gemini-code-assist
Contributor

Summary of Changes

Hello @kylesayrs, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the memory efficiency and robustness of the model calibration process. By introducing a smart mechanism to conditionally offload the Language Model (LM) head to a meta device, it mitigates Out-Of-Memory issues that can arise during calibration, particularly when using large batch sizes. This ensures that the LM head is only active when a specific modifier requires its calibration, leading to more optimized resource utilization and a smoother calibration experience.

Highlights

  • Memory Optimization for Calibration: Introduced a mechanism to disable the LM head during model calibration by moving it to a meta device, preventing Out-Of-Memory (OOM) errors, especially with large batch sizes.
  • Conditional LM Head Handling: Added a utility function (requires_lm_head_calibration) to determine if any modifiers explicitly require the LM head for calibration, ensuring it's only disabled when safe (a rough sketch follows this list).
  • Pipeline Integration: Updated both basic and sequential calibration pipelines to leverage the new conditional LM head disabling logic, improving efficiency across different calibration workflows.
  • Refined Embedding Untying: Modified the SpinQuant modifier to conditionally untie word embeddings and the LM head only when they are explicitly targeted, allowing for more precise control and potential memory savings.
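
A rough, hypothetical sketch of how such a check might be used; requires_lm_head_calibration is named in this PR, but the modifier interface shown here (resolved_targets) is an assumption for illustration only:

```python
import torch


def requires_lm_head_calibration(model: torch.nn.Module, modifiers) -> bool:
    # Hypothetical: the lm_head must stay on a real device if any modifier
    # explicitly targets it for calibration
    lm_head = model.get_output_embeddings()
    return any(
        lm_head in modifier.resolved_targets(model)  # assumed interface
        for modifier in modifiers
    )


# A calibration pipeline would then disable the lm_head only when safe:
#     if not requires_lm_head_calibration(model, modifiers):
#         stack.enter_context(disable_lm_head(model))
```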

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a performance optimization to disable the lm_head during calibration to prevent out-of-memory errors, which is a significant improvement for handling large batch sizes. The implementation is well-structured, adding new utility functions and integrating them into the calibration pipelines. I've identified one potential high-severity issue in the requires_lm_head_calibration function that could lead to incorrect calibration behavior. My review includes a specific comment and a code suggestion to address this.

@kylesayrs kylesayrs force-pushed the kylesayrs/modifiers-expose-targets branch 3 times, most recently from 9536dad to d4b3e7a on December 1, 2025 at 23:42
```python
        return input.to("meta") @ dummy_weight.T

    with patch_attr(lm_head, "forward", dummy_forward.__get__(lm_head)):
        yield
```
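
For context, the patched forward above is entered as a subcontext of calibration_forward_context; a minimal sketch of how that composition might look (the other subcontexts shown are assumptions about a typical calibration setup, not necessarily the actual implementation):

```python
import contextlib

import torch


@contextlib.contextmanager
def calibration_forward_context(model: torch.nn.Module):
    # Stack calibration-time contexts so every side effect (disabled
    # gradients, patched lm_head forward, ...) is undone on exit
    with contextlib.ExitStack() as stack:
        stack.enter_context(torch.no_grad())
        stack.enter_context(disable_lm_head(model))  # subcontext from this PR
        yield
```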
Collaborator


Still getting used to the code paradigms here; this nested context manager design is an interesting approach. I'm pretty sure I had similar situations in torchao and never would have considered this.

Collaborator Author

@kylesayrs kylesayrs Dec 2, 2025


Yeah, I love the context manager pattern, as it reminds implementers to clean up any side effects that might affect downstream code.

HDCharles
HDCharles previously approved these changes Dec 2, 2025
Collaborator

@HDCharles HDCharles left a comment


looks good

Base automatically changed from kylesayrs/refactor-embeddings to main December 2, 2025 16:38
@kylesayrs kylesayrs dismissed HDCharles’s stale review December 2, 2025 16:38

The base branch was changed.

Signed-off-by: Kyle Sayers <[email protected]>
@kylesayrs kylesayrs force-pushed the kylesayrs/modifiers-expose-targets branch from 34814c7 to 6559de0 on December 2, 2025 at 19:11