Conversation

@vkuzo vkuzo commented Nov 21, 2025

Summary:

not for land, just a demo

  1. during weight loading, keep track of how many elements we have loaded
  2. when we have loaded all the elements, call post-processing

This can be used to run weight post-processing in a streaming fashion to minimize peak GPU memory usage. It only works if we can assume each weight chunk is loaded exactly once.
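The two steps above can be sketched as a small tracker. This is a minimal illustration of the idea, not the PR's actual code; the class and names here are hypothetical:

```python
class StreamingPostprocessTracker:
    """Count loaded elements of one weight tensor and fire a callback once
    the running total reaches the expected size. Assumes each weight chunk
    is loaded exactly once, as noted above."""

    def __init__(self, total_numel: int, on_fully_loaded):
        self.total_numel = total_numel      # numel of the full weight
        self.on_fully_loaded = on_fully_loaded
        self.loaded_numel = 0               # elements copied in so far

    def record_chunk(self, chunk_numel: int) -> None:
        self.loaded_numel += chunk_numel
        if self.loaded_numel == self.total_numel:
            # The weight is complete: post-process (e.g. quantize to fp8)
            # now, instead of waiting for the whole model to finish loading.
            self.on_fully_loaded()


# Example: a 4x4 weight loaded in two row shards.
events = []
tracker = StreamingPostprocessTracker(16, lambda: events.append("quantized"))
tracker.record_chunk(8)   # first shard: weight not complete yet
tracker.record_chunk(8)   # second shard: callback fires here
```

In the real patch this bookkeeping lives on the parameter itself (the diff below tracks it via `param._loaded_numel`), but the counting logic is the same.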

Test Plan:

Tested locally with facebook/opt-125m and fp8 online quantization.



@chatgpt-codex-connector (bot) left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.



@gemini-code-assist (bot) left a comment


Code Review

This pull request introduces a proof-of-concept for online FP8 quantization with streaming weight post-processing. The approach is clever, patching the weight loader to trigger post-processing as soon as a weight tensor is fully loaded. This can help reduce peak memory usage during model loading.

My main feedback is to improve the robustness of the state management. The flag _already_called_process_weights_after_loading is currently stored on the Fp8LinearMethod instance. While this works with the current code structure, it's fragile. Attaching this state to the layer object instead would make the implementation more robust against future changes, such as instance reuse for optimization. I've added specific comments with code suggestions to address this.

```python
del param._loaded_numel
# Prevent the usual `process_weights_after_loading` call from doing anything
self._already_called_process_weights_after_loading = True
```

Severity: high

Storing _already_called_process_weights_after_loading on self (the Fp8LinearMethod instance) makes the design fragile. Although a new instance is currently created for each layer, this might change in the future (e.g., for optimization), which could lead to this flag persisting incorrectly across different layers.

To make this more robust, this state should be attached to the layer object, which is guaranteed to be unique. This change should be made in conjunction with the corresponding check in process_weights_after_loading.

Suggested change:

```diff
-self._already_called_process_weights_after_loading = True
+layer._already_called_process_weights_after_loading = True
```
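To illustrate why the per-layer flag is the safer choice, here is a minimal, hypothetical sketch (class names invented for the example) of a quant-method instance shared across layers:

```python
class SharedQuantMethod:
    """Stand-in for a quant method whose instance might be reused across
    layers. The guard flag lives on the layer, so reuse stays correct."""

    def process_weights_after_loading(self, layer) -> None:
        # Per-layer flag: if it lived on `self`, a shared instance would
        # silently skip post-processing for every layer after the first.
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return
        layer._already_called_process_weights_after_loading = True
        layer.processed = True  # stand-in for the real weight post-processing


class Layer:
    pass


method = SharedQuantMethod()        # one instance shared by both layers
layer_a, layer_b = Layer(), Layer()
method.process_weights_after_loading(layer_a)
method.process_weights_after_loading(layer_b)  # still runs: flag is per-layer
```

With the flag on `self`, the second call would be skipped as soon as instances are shared; with it on the layer, each layer is processed exactly once regardless.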

```python
layer.register_parameter("input_scale", None)

def process_weights_after_loading(self, layer: Module) -> None:
    if getattr(self, "_already_called_process_weights_after_loading", False):
```
Severity: high

To make the state management more robust and in conjunction with the suggested change for setting this flag, this check should be on the layer object instead of self.

Suggested change:

```diff
-if getattr(self, "_already_called_process_weights_after_loading", False):
+if getattr(layer, "_already_called_process_weights_after_loading", False):
```

Signed-off-by:  <[email protected]>
@vkuzo force-pushed the 20251121_fp8_online_quant_hack branch from 5326892 to 9583e3b on November 21, 2025 at 19:07