
Conversation

brian-dellabetta
Collaborator

@brian-dellabetta brian-dellabetta commented Aug 21, 2025

SUMMARY:
Prerequisites:

This allows for multi-modifier support by scoping the application of quantization config/status to only the modules in the model that match the given targets/ignore configuration, rather than all modules. Initialization of observers is moved to on_start (instead of on_initialize) to match their removal on_end (rather than on_finalize). This prevents collisions during the multi-modifier lifecycle.

  • Update AWQ
  • Update QuantizationModifier
  • Update QuantizationMixin
  • Update GPTQ
  • No other quantization modifiers exist

TEST PLAN:

  • Tests were added in neuralmagic/compressed-tensors#432 ([Multi-Modifier] Scoped apply quantization config) to confirm correct application of multiple modifiers.
  • Added an example in this PR showing how AWQ and GPTQ can be applied heterogeneously to a model, along with a small README (a rough sketch of such a recipe is shown below). Logs show alternating AWQ and GPTQ messages for the "sequential" pipeline and correct behavior for the "independent" pipeline. The model checkpoint for the sequential pipeline shows correct application of W8A8 to self_attn layers and W4A16 to mlp layers; config.json and the safetensors weights all look as expected.
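
For illustration, a minimal sketch of the kind of heterogeneous recipe this enables. This is not the exact example added in the PR: the regex targets, scheme names, model, dataset, and `oneshot` arguments below are assumptions based on the test plan description.

```
from llmcompressor import oneshot
from llmcompressor.modifiers.awq import AWQModifier
from llmcompressor.modifiers.quantization import GPTQModifier

recipe = [
    # AWQ quantizes only the MLP projections to W4A16 ...
    AWQModifier(targets=[r"re:.*mlp\.(gate|up|down)_proj$"], scheme="W4A16"),
    # ... while GPTQ quantizes only the attention projections to W8A8.
    GPTQModifier(targets=[r"re:.*self_attn\.(q|k|v|o)_proj$"], scheme="W8A8"),
]

oneshot(
    model="meta-llama/Llama-3.2-1B-Instruct",  # placeholder model
    dataset="open_platypus",                   # placeholder calibration dataset
    recipe=recipe,
    max_seq_length=2048,
    num_calibration_samples=256,
    pipeline="sequential",                     # or "independent"
)
```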


👋 Hi! Thank you for contributing to llm-compressor. Please add the ready label when the PR is ready for review.

Note: This is required to complete the testing suite, please only add the label once the PR is code complete and local testing has been performed.

@brian-dellabetta brian-dellabetta force-pushed the bdellabe/scoped-quant-status branch 2 times, most recently from 5fec983 to 2f93072 on August 28, 2025 16:51
@brian-dellabetta brian-dellabetta changed the title from "[Multi-modifier] Support scoped appliation of quantization config/status" to "[Multi-modifier] Support scoped application of quantization config/status" Sep 2, 2025
@brian-dellabetta brian-dellabetta force-pushed the bdellabe/scoped-quant-status branch from 2f93072 to f99db2f on September 11, 2025 16:43
@brian-dellabetta brian-dellabetta marked this pull request as ready for review September 15, 2025 20:38
@brian-dellabetta brian-dellabetta added the ready (When a PR is ready for review) label Sep 15, 2025
@brian-dellabetta brian-dellabetta removed the ready (When a PR is ready for review) label Sep 15, 2025
Collaborator

@kylesayrs kylesayrs left a comment

Consider adding some basic tests / common use cases, otherwise looks good!

kylesayrs previously approved these changes Sep 18, 2025
Collaborator

@kylesayrs kylesayrs left a comment

Nice job

Collaborator

@shanjiaz shanjiaz left a comment

Woohoo!

Collaborator

@fynnsu fynnsu left a comment

Looks good!

@brian-dellabetta brian-dellabetta merged commit 27303c4 into main Sep 22, 2025
8 checks passed
@brian-dellabetta brian-dellabetta deleted the bdellabe/scoped-quant-status branch September 22, 2025 21:25
def test_serialize_actorder(has_actorder, actorder, exp_actorder):
    if has_actorder:
-       modifier = GPTQModifier(targets=["Linear"], actorder=actorder)
+       modifier = GPTQModifier(targets=["Linear"], scheme="W8A8", actorder=actorder)
Collaborator

How was this targeting before you added the scheme?

Collaborator Author

I think it just passed init/validation but was never used. It was never applied to a model, so it would never have worked. I just added it to make sure improper configuration wasn't the reason the test was failing (it was ultimately something else causing the failure).

self._calibration_hooks = self._initialize_hooks(model)
model.apply(apply_calibration_status)
for _, module in match_named_modules(model, self.targets, self.ignore):
    self._initialize_observers(module)
Collaborator

Why can't we keep this in initialize_quantization?

Collaborator Author

@brian-dellabetta brian-dellabetta Sep 22, 2025

Observers should be initialized on_start to align with their removal on_end, so this was moved into on_start. Without this change, with multiple quantization modifiers the lifecycle would trigger observer hooks before the modifier starts (before it has seen any data), since that can now happen during a previous modifier's lifecycle.
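
As a self-contained toy (not llm-compressor code; the hook names are invented), the sketch below shows the collision: if every modifier registers its observer hooks up front, a later modifier's observers already fire during an earlier modifier's calibration pass, whereas registering and removing them per lifecycle keeps the passes separate.

```
import torch

module = torch.nn.Linear(4, 4)
events = []

def make_observer(name):
    # Stand-in for a quantization observer attached as a forward hook.
    def observer(mod, inputs, output):
        events.append(name)
    return observer

# "on_initialize"-style: both modifiers register their observers up front.
handle_a = module.register_forward_hook(make_observer("A"))
handle_b = module.register_forward_hook(make_observer("B"))
module(torch.randn(1, 4))   # calibration pass driven by modifier A
print(events)               # ['A', 'B'] -- B observes data before it has started
handle_a.remove(); handle_b.remove()

# "on_start"/"on_end"-style: each modifier attaches its observers only for
# the span of its own lifecycle, so the calibration passes no longer collide.
events.clear()
handle_a = module.register_forward_hook(make_observer("A"))
module(torch.randn(1, 4))   # A's calibration
handle_a.remove()           # A's on_end removes its observers
handle_b = module.register_forward_hook(make_observer("B"))
module(torch.randn(1, 4))   # B's calibration
handle_b.remove()
print(events)               # ['A', 'B'] -- each from its own pass
```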

brian-dellabetta added a commit that referenced this pull request Sep 26, 2025
…les (#1869)

SUMMARY:
#1772 introduced a bug when running NVFP4 quantization schemes. The call
to `update_fused_layer_weight_global_scales` needs to run on Attention and
MLP layers, which are not included in `targets`, since `targets` consists of
the quantizable layers inside Attention/MLP. This PR fixes that by running
`update_fused_layer_weight_global_scales` on every module instead of only
the targeted ones, which is safe because the call is idempotent and only
modifies modules with NVFP4 schemes. This is only a problem in
`QuantizationModifier`; AWQ cannot be used with NVFP4.
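
A rough sketch of the shape of the fix (illustrative only, not the actual source; `update_fn` stands in for `update_fused_layer_weight_global_scales`, whose import is not shown in this thread):

```
def update_fused_global_scales_everywhere(model, update_fn):
    # Apply the fused-scale update to every module, not just the targeted
    # quantizable Linears, so the enclosing Attention/MLP modules are visited.
    # Safe because update_fn is idempotent and is a no-op for modules that do
    # not carry NVFP4 global scales.
    for _, module in model.named_modules():
        update_fn(module)
```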

TEST PLAN:
Confirmed that the working vs. broken global scales are mismatched
because the update is never run:
```
model.layers.0.self_attn.k_proj.weight_global_scale -- working 9600.0, broken 12992.0
model.layers.0.self_attn.q_proj.weight_global_scale -- working 9600.0, broken 9600.0
model.layers.0.self_attn.v_proj.weight_global_scale -- working 9600.0, broken 12160.0
```

And these changes resolve the regression:
Before
```
vllm (pretrained=/home/dsikka/llm-compressor/examples/quantization_w4a4_fp4/Qwen3-30B-A3B-NVFP4,dtype=auto,max_model_len=4096,add_bos_token=True), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8135|±  |0.0107|
|     |       |strict-match    |     5|exact_match|↑  |0.8097|±  |0.0108|
```
After
```
vllm (pretrained=/home/brian-dellabetta/projects/llm-compressor/Qwen3-30B-A3B-NVFP4,dtype=auto,max_model_len=4096,add_bos_token=True), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8620|±  |0.0095|
|     |       |strict-match    |     5|exact_match|↑  |0.8575|±  |0.0096|
```

---------

Signed-off-by: Brian Dellabetta <[email protected]>
brian-dellabetta added a commit that referenced this pull request Sep 30, 2025
SUMMARY:
This PR
- [x] Removes `pile-val-dataset` from e2e tests, as it is no longer used
in examples and the processing logic was flawed
- [x] Fixes a model validation error introduced in #1772 that was
preventing AWQModifier from running one of its validations, leaving it in
an invalid state (`AWQModifier.validate_model_after` was preventing
`QuantizationMixin.validate_model_after` from running; the shadowing
pattern is sketched below). With these changes, tests pass and the
compressed model generates meaningful responses, whereas it previously
generated all 0s.
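
For readers unfamiliar with that failure mode: in pydantic v2, a subclass method decorated with `@model_validator` that shares the parent's method name shadows the parent's validator, so only the subclass's runs. A self-contained toy (not llm-compressor code; the class and field names are invented, and delegating via `super()` is one possible fix, not necessarily the one used here):

```
from pydantic import BaseModel, model_validator


class QuantMixinLike(BaseModel):
    scheme: str = "W4A16"

    @model_validator(mode="after")
    def validate_model_after(self):
        # Mixin-level check that must not be skipped.
        if self.scheme not in {"W4A16", "W8A8"}:
            raise ValueError(f"unsupported scheme {self.scheme}")
        return self


class AWQLikeModifier(QuantMixinLike):
    duo_scaling: bool = True

    @model_validator(mode="after")
    def validate_model_after(self):
        # The same method name shadows the mixin's validator, so only this one
        # runs unless it explicitly delegates to the parent.
        super().validate_model_after()
        # Subclass-specific checks go here.
        return self


AWQLikeModifier(scheme="W8A8")    # passes both checks
# AWQLikeModifier(scheme="NVFP4") # now correctly rejected by the mixin's check
```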


TEST PLAN:
`CADENCE=nightly
TEST_DATA_FILE=tests/e2e/vLLM/configs/w4a16_grouped_quant_sym_awq.yaml
pytest -s tests/e2e/vLLM/test_vllm.py` and
`CADENCE=nightly
TEST_DATA_FILE=tests/e2e/vLLM/configs/w4a16_grouped_quant_asym_awq.yaml
pytest -s tests/e2e/vLLM/test_vllm.py`
both pass with output like
```
PROMPT:
The capital of France is
GENERATED TEXT:
 Paris, which is also the country's largest city.

PROMPT:
The president of the US is
GENERATED TEXT:
 named, but the name of the Vice President is not given. In the case

PROMPT:
My name is
GENERATED TEXT:
 Emily and I am from Canada. I have always been fascinated with
```

---------

Signed-off-by: Brian Dellabetta <[email protected]>