-
Notifications
You must be signed in to change notification settings - Fork 453
feat: defer activation qparam calculation to sequential epoch end #2455
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
dzhengAP
wants to merge
4
commits into
vllm-project:main
Choose a base branch
from
dzhengAP:feat/deferred-activation-qparams
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 3 commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
2320a6a
feat: defer activation qparam calculation to sequential epoch end
dzhengAP 26c29ad
fix: use update_deferred_stats for all observer types
dzhengAP 316114a
refactor: always disable quantization during calibration, re-enable f…
dzhengAP 684d710
refactor: address HDCharles review comments
dzhengAP File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Some comments aren't visible on the classic Files Changed page.
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,6 +17,7 @@ | |
| from torch.nn import Module | ||
|
|
||
| from llmcompressor.observers import Observer | ||
| from llmcompressor.observers.base import calibrate_module_from_observer | ||
|
|
||
| __all__ = [ | ||
| "initialize_observer", | ||
|
|
@@ -30,6 +31,7 @@ | |
| "calibrate_query_hook", | ||
| "calibrate_key_hook", | ||
| "calibrate_value_hook", | ||
| "flush_activation_qparams", | ||
| ] | ||
|
|
||
|
|
||
|
|
@@ -156,15 +158,20 @@ def update_weight_zp_scale(module: Module): | |
| call_observer(module=module, base_name="weight") | ||
|
|
||
|
|
||
| def calibrate_activations(module: Module, value: torch.Tensor, base_name: str): | ||
| def calibrate_activations( | ||
| module: Module, value: torch.Tensor, base_name: str, stats_only: bool = False | ||
| ): | ||
| """ | ||
| Calibrate input or output activations by calling the a module's attached | ||
| observer. | ||
|
|
||
| :param module: torch.nn.Module | ||
| :param base_name: substring used to fetch the observer, scales, and zp | ||
| :param value: torch.Tensor to be passed to the observer | ||
|
|
||
| :param stats_only: if True, only update running statistics in the observer | ||
| (accumulate min/max) without computing or writing scale/zero_point. | ||
| Used during deferred qparam calibration — qparams are computed once | ||
| at epoch end via flush_activation_qparams instead of per batch. | ||
| """ | ||
| # If empty tensor, can't update zp/scale | ||
| # Case for MoEs | ||
|
|
@@ -184,6 +191,15 @@ def calibrate_activations(module: Module, value: torch.Tensor, base_name: str): | |
| if quantization_args.strategy == QuantizationStrategy.TENSOR_GROUP: | ||
| calculate_gparam = True | ||
|
|
||
| # In deferred (stats_only) mode: call the observer to accumulate running | ||
| # min/max stats but do NOT write scale/zero_point yet. | ||
| # Qparams are written once at epoch end via flush_activation_qparams. | ||
| if stats_only: | ||
| observer = getattr(module, f"{base_name}_observer", None) | ||
| if observer is not None: | ||
| observer.update_deferred_stats(value) | ||
| return | ||
|
|
||
| call_observer( | ||
| module=module, | ||
| base_name=base_name, | ||
|
|
@@ -195,44 +211,40 @@ def calibrate_activations(module: Module, value: torch.Tensor, base_name: str): | |
|
|
||
| def calibrate_input_hook(module: Module, args: Any): | ||
| """ | ||
| Hook to calibrate input activations. | ||
| Will call the observers to update the scales/zp before applying | ||
| input QDQ in the module's forward pass. | ||
| Hook to accumulate input activation statistics (min/max) in the observer. | ||
| Scale and zero_point are not written here; they are computed once per subgraph | ||
| at epoch end via flush_activation_qparams. | ||
| """ | ||
| args = args[0] if isinstance(args, tuple) else args | ||
| calibrate_activations(module, value=args, base_name="input") | ||
| calibrate_activations(module, value=args, base_name="input", stats_only=True) | ||
|
|
||
|
|
||
| def calibrate_output_hook(module: Module, _args: Any, output: torch.Tensor): | ||
| """ | ||
| Hook to calibrate output activations. | ||
| Will call the observers to update the scales/zp before applying | ||
| output QDQ. | ||
| Hook to accumulate output activation statistics (min/max) in the observer. | ||
| Scale and zero_point are not written here; they are computed once per subgraph | ||
| at epoch end via flush_activation_qparams. | ||
| Note: forward_quantize is intentionally absent — hooks only collect statistics. | ||
| """ | ||
| calibrate_activations( | ||
| module, | ||
| value=output, | ||
| base_name="output", | ||
| ) | ||
| output = forward_quantize( | ||
| module=module, | ||
| value=output, | ||
| base_name="output", | ||
| args=module.quantization_scheme.output_activations, | ||
| stats_only=True, | ||
| ) | ||
| return output | ||
|
|
||
|
|
||
| def calibrate_query_hook(module: Module, query_states: torch.Tensor): | ||
| calibrate_activations(module, query_states, base_name="q") | ||
| calibrate_activations(module, query_states, base_name="q", stats_only=True) | ||
|
|
||
|
|
||
| def calibrate_key_hook(module: Module, key_states: torch.Tensor): | ||
| calibrate_activations(module, key_states, base_name="k") | ||
| calibrate_activations(module, key_states, base_name="k", stats_only=True) | ||
|
|
||
|
|
||
| def calibrate_value_hook(module: Module, value_states: torch.Tensor): | ||
| calibrate_activations(module, value_states, base_name="v") | ||
| calibrate_activations(module, value_states, base_name="v", stats_only=True) | ||
|
|
||
|
|
||
| def apply_calibration_status(module: Module): | ||
|
|
@@ -273,3 +285,29 @@ def reset_quantization_status(model: Module): | |
| for module in model.modules(): | ||
| if hasattr(module, "quantization_status"): | ||
| delattr(module, "quantization_status") | ||
|
|
||
|
|
||
| def flush_activation_qparams(module: Module): | ||
| """ | ||
| Compute and write final activation qparams from each observer's accumulated | ||
| running statistics, then free those statistics to reduce memory. | ||
|
|
||
| This is called once at SEQUENTIAL_EPOCH_END for each subgraph, replacing the | ||
| per-batch qparam updates that were previously triggered by calibration hooks. | ||
| It is a no-op for modules with no quantization scheme or no activation observers. | ||
|
|
||
| Note: weight observers are not touched here — weight qparams are always computed | ||
| up-front in ``on_start`` via ``update_weight_zp_scale``. | ||
|
|
||
| apply to targeted modules with: | ||
| for _, module in match_named_modules(...): | ||
| flush_activation_qparams(module) | ||
|
|
||
| :param module: module to flush activation qparams for | ||
| """ | ||
| scheme = getattr(module, "quantization_scheme", None) | ||
| if scheme is None: | ||
| return | ||
|
|
||
| for base_name in ("input", "output", "q", "k", "v"): | ||
dzhengAP marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| calibrate_module_from_observer(module, base_name) | ||
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.