Skip to content

Commit 810d7fd

Browse files
committed
feat: add iMatrix weighted MSE observer and IMatrixGatherer
- imatrix_mse observer with E[x²] importance weighting - IMatrixGatherer transform using match_named_modules + CPU offload - Unit tests for observer and gatherer - E2E integration tests RFC #2456 Signed-off-by: Gilles Turpin <turpingilles15@gmail.com>
1 parent cf3bd64 commit 810d7fd

File tree

12 files changed

+1471
-0
lines changed

12 files changed

+1471
-0
lines changed

examples/quantization_w4a16/README.md

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,105 @@ We can see the resulting scores look good!
138138
| | |strict-match | 5|exact_match||0.720|± |0.0285|
139139
```
140140

141+
---
142+
143+
## iMatrix Importance-Weighted Quantization
144+
145+
`imatrix_mse` is an observer that uses per-channel activation importance (E[]) to weight quantization error during range selection. Channels that carry more signal get more careful range optimization.
146+
147+
Two components work together:
148+
- **`IMatrixGatherer`**: triggers a calibration pass so the observer can collect importance data
149+
- **`imatrix_mse` observer**: collects E[] per input channel via forward pre-hooks and uses importance weighting in the MSE grid search: `err = sum(importance * |Q(w) - w|^p)`
150+
151+
> See [RFC #2456](https://github.com/vllm-project/llm-compressor/discussions/2456) for the full design discussion.
152+
153+
### Usage
154+
155+
```bash
156+
python3 llama3_imatrix_example.py
157+
```
158+
159+
The simplest setup uses `preset_name_to_scheme` to configure W4A16 and swaps in the `imatrix_mse` observer:
160+
161+
```python
162+
from compressed_tensors.quantization import preset_name_to_scheme
163+
from llmcompressor.modifiers.quantization import QuantizationModifier
164+
from llmcompressor.modifiers.transform.imatrix import IMatrixGatherer
165+
166+
scheme = preset_name_to_scheme("W4A16", ["Linear"])
167+
scheme.weights.observer = "imatrix_mse"
168+
169+
recipe = [
170+
IMatrixGatherer(ignore=["lm_head"]),
171+
QuantizationModifier(
172+
config_groups={"group_0": scheme},
173+
ignore=["lm_head"],
174+
),
175+
]
176+
```
177+
178+
### Composing with GPTQ
179+
180+
iMatrix composes with GPTQ by providing importance-weighted ranges for the Hessian-based rounding:
181+
182+
```python
183+
from llmcompressor.modifiers.gptq import GPTQModifier
184+
185+
scheme = preset_name_to_scheme("W4A16", ["Linear"])
186+
scheme.weights.observer = "imatrix_mse"
187+
188+
recipe = [
189+
IMatrixGatherer(ignore=["lm_head"]),
190+
GPTQModifier(
191+
config_groups={"group_0": scheme},
192+
ignore=["lm_head"],
193+
),
194+
]
195+
```
196+
197+
### Results
198+
199+
W4A16, Llama-3.1-8B, WikiText-2 token-level perplexity (141 chunks x 2048):
200+
201+
**group_size=128:**
202+
203+
| Config | PPL |
204+
|---|---|
205+
| FP16 baseline | 6.24 |
206+
| RTN `memoryless_minmax` | 6.96 |
207+
| RTN `imatrix_mse` | 6.97 |
208+
| GPTQ | 6.89 |
209+
| GPTQ + `imatrix_mse` | 6.82 |
210+
211+
**group_size=32:**
212+
213+
| Config | PPL |
214+
|---|---|
215+
| RTN `memoryless_minmax` | 6.74 |
216+
| RTN `imatrix_mse` | 6.73 |
217+
| GPTQ | 6.70 |
218+
| GPTQ + `imatrix_mse` | 6.66 |
219+
220+
GPTQ + `imatrix_mse` is the best result at both group sizes with default observer settings. iMatrix never degrades quality.
221+
222+
### Observer Parameters
223+
224+
The observer accepts optional `observer_kwargs` for fine-tuning:
225+
226+
| Parameter | Default | Description |
227+
|---|---|---|
228+
| `norm` | 2.4 | Error exponent (`\|Q(w) - w\|^norm`) |
229+
| `maxshrink` | 0.20 | Max fraction to shrink the range |
230+
| `grid` | 20 | Number of grid search steps |
231+
| `patience` | 5 | Early stopping after N steps without improvement |
232+
| `maxgrow` | 0.0 | Max fraction to grow the range beyond observed min/max |
233+
234+
The defaults work well for GPTQ composition. For RTN, increasing `maxshrink` (e.g. 0.80) allows the observer to optimize ranges more aggressively:
235+
236+
```python
237+
scheme.weights.observer_kwargs = {"maxshrink": 0.80}
238+
```
239+
141240
### Questions or Feature Request?
142241

143242
Please open up an issue on `vllm-project/llm-compressor`
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
from compressed_tensors.offload import dispatch_model
2+
from compressed_tensors.quantization import preset_name_to_scheme
3+
from transformers import AutoModelForCausalLM, AutoTokenizer
4+
5+
from llmcompressor import oneshot
6+
from llmcompressor.modifiers.quantization import QuantizationModifier
7+
from llmcompressor.modifiers.transform.imatrix import IMatrixGatherer
8+
9+
# Select model and load it.
10+
model_id = "meta-llama/Meta-Llama-3.1-8B"
11+
12+
model = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto")
13+
tokenizer = AutoTokenizer.from_pretrained(model_id)
14+
15+
# Select calibration dataset.
16+
DATASET_ID = "open_platypus"
17+
18+
# Select number of samples. 512 samples is a good place to start.
19+
# Increasing the number of samples can improve accuracy.
20+
NUM_CALIBRATION_SAMPLES = 512
21+
MAX_SEQUENCE_LENGTH = 2048
22+
23+
# Configure the quantization algorithm to run.
24+
# * trigger a calibration pass with IMatrixGatherer so the observer can collect E[x²]
25+
# * quantize the weights to 4 bit with group size 128
26+
# * use imatrix_mse observer to weight quantization error by channel importance
27+
scheme = preset_name_to_scheme("W4A16", ["Linear"])
28+
scheme.weights.observer = "imatrix_mse"
29+
30+
recipe = [
31+
IMatrixGatherer(ignore=["lm_head"]),
32+
QuantizationModifier(
33+
config_groups={"group_0": scheme},
34+
ignore=["lm_head"],
35+
),
36+
]
37+
38+
# Apply algorithms.
39+
oneshot(
40+
model=model,
41+
dataset=DATASET_ID,
42+
splits={"calibration": "train[:5%]"},
43+
recipe=recipe,
44+
max_seq_length=MAX_SEQUENCE_LENGTH,
45+
num_calibration_samples=NUM_CALIBRATION_SAMPLES,
46+
)
47+
48+
# Confirm generations of the quantized model look sane.
49+
print("\n\n")
50+
print("========== SAMPLE GENERATION ==============")
51+
dispatch_model(model)
52+
sample = tokenizer("Hello my name is", return_tensors="pt")
53+
sample = {key: value.to(model.device) for key, value in sample.items()}
54+
output = model.generate(**sample, max_new_tokens=100)
55+
print(tokenizer.decode(output[0]))
56+
print("==========================================\n\n")
57+
58+
# Save to disk compressed.
59+
SAVE_DIR = model_id.rstrip("/").split("/")[-1] + "-W4A16-G128-imatrix"
60+
model.save_pretrained(SAVE_DIR, save_compressed=True)
61+
tokenizer.save_pretrained(SAVE_DIR)

src/llmcompressor/modifiers/quantization/calibration.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ def initialize_observer(
8282
observer, base_name=base_name, args=args, module=module
8383
)
8484
module.register_module(f"{base_name}_observer", observer)
85+
observer.init(module)
8586

8687

8788
def call_observer(
@@ -264,6 +265,7 @@ def freeze_module_quantization(module: Module):
264265
for name in ("input", "weight", "output", "q", "k", "v"):
265266
obs_name = f"{name}_observer"
266267
if hasattr(module, obs_name):
268+
getattr(module, obs_name).detach(module)
267269
delattr(module, obs_name)
268270

269271
module.quantization_status = QuantizationStatus.FROZEN
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# ruff: noqa
2+
3+
from .base import *
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
from typing import List, Union
2+
3+
from compressed_tensors.quantization import disable_quantization
4+
from compressed_tensors.utils import match_named_modules
5+
from pydantic import Field
6+
7+
from llmcompressor.core import Event, State
8+
from llmcompressor.modifiers import Modifier
9+
from llmcompressor.modifiers.quantization.quantization.mixin import QuantizationMixin
10+
11+
__all__ = ["IMatrixGatherer"]
12+
13+
14+
class IMatrixGatherer(Modifier, QuantizationMixin):
15+
"""
16+
Lifecycle trigger for iMatrix importance collection.
17+
18+
Triggers a calibration pass so that ``IMatrixMSEObserver`` can collect
19+
E[x²] via its ``init()`` hook. Does **not** quantize weights — the
20+
actual quantization is done by the subsequent
21+
``QuantizationModifier`` / ``GPTQModifier``.
22+
23+
The observer's ``detach()`` method computes ``_imatrix_importance``
24+
from the accumulated statistics and leaves it on the module for the
25+
next quantization pass to consume.
26+
27+
Example recipe::
28+
29+
recipe:
30+
- IMatrixGatherer:
31+
ignore: ["lm_head"]
32+
- QuantizationModifier:
33+
config_groups:
34+
group_0:
35+
targets: ["Linear"]
36+
weights:
37+
observer: imatrix_mse
38+
39+
Or composed with GPTQ::
40+
41+
recipe:
42+
- IMatrixGatherer:
43+
ignore: ["lm_head"]
44+
- GPTQModifier:
45+
config_groups:
46+
group_0:
47+
targets: ["Linear"]
48+
weights:
49+
observer: imatrix_mse
50+
51+
.. note::
52+
Auto-prepend (inserting the gatherer automatically when
53+
``imatrix_mse`` is detected in a recipe) is planned for a
54+
follow-up PR.
55+
56+
:param scheme: quantization preset used to build the internal config.
57+
Defaults to ``"W4A16"``. The actual bit-width does not matter
58+
because weights are never quantized by this modifier.
59+
:param weight_observer: observer to attach during calibration.
60+
Must be ``"imatrix_mse"`` (default).
61+
:param ignore: layer name patterns to skip (default: ``["lm_head"]``)
62+
:param targets: module types to instrument (default: ``["Linear"]``)
63+
"""
64+
65+
scheme: str = "W4A16"
66+
weight_observer: str = "imatrix_mse"
67+
ignore: List[str] = Field(default_factory=lambda: ["lm_head"])
68+
targets: Union[str, List[str]] = Field(default_factory=lambda: ["Linear"])
69+
70+
# ------------------------------------------------------------------ #
71+
# Lifecycle
72+
# ------------------------------------------------------------------ #
73+
74+
def on_initialize(self, state: State, **kwargs) -> bool:
75+
QuantizationMixin.initialize_quantization(self, state.model)
76+
return True
77+
78+
def on_start(self, state: State, event: Event, **kwargs):
79+
self.started_ = True
80+
QuantizationMixin.start_calibration(self, state.model)
81+
# Disable quantized forward — we only need observer hooks for E[x²]
82+
state.model.apply(disable_quantization)
83+
84+
def on_end(self, state: State, event: Event, **kwargs):
85+
self.ended_ = True
86+
QuantizationMixin.end_calibration(self, state.model)
87+
# Disable quantized forward so the model is clean for the next modifier
88+
state.model.apply(disable_quantization)
89+
90+
def on_finalize(self, state: State, **kwargs) -> bool:
91+
if not self.ended_:
92+
self.on_end(state, None)
93+
94+
# Clean up importance tensors so they don't end up in checkpoint
95+
for _, module in match_named_modules(
96+
state.model, self.resolved_targets, self.ignore
97+
):
98+
if hasattr(module, "_imatrix_importance"):
99+
del module._imatrix_importance
100+
101+
return True

src/llmcompressor/observers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,4 @@
1414
from .moving_base import *
1515
from .min_max import *
1616
from .mse import *
17+
from .imatrix import *

src/llmcompressor/observers/base.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,24 @@ def _get_module_param(self, name: str) -> Optional[torch.nn.Parameter]:
133133
with align_module_device(module):
134134
return getattr(module, f"{self.base_name}_{name}", None)
135135

136+
def init(self, module: torch.nn.Module) -> None:
137+
"""
138+
Called when the observer is attached to a module.
139+
Subclasses can override to register hooks or initialize state.
140+
141+
:param module: the module this observer is being attached to
142+
"""
143+
pass
144+
145+
def detach(self, module: torch.nn.Module) -> None:
146+
"""
147+
Called before the observer is deleted from a module.
148+
Subclasses can override to remove hooks and clean up module attributes.
149+
150+
:param module: the module this observer is being removed from
151+
"""
152+
pass
153+
136154
def _check_has_global_scale(self, global_scale: Optional[torch.nn.Parameter]):
137155
if (
138156
self.args.strategy == QuantizationStrategy.TENSOR_GROUP

0 commit comments

Comments
 (0)