Commit 8429a6f

feat: add iMatrix weighted MSE observer and IMatrixGatherer
Signed-off-by: Gilles Turpin <turpingilles15@gmail.com>
1 parent 5ae2e14

File tree

12 files changed: +1596 −0 lines changed

examples/quantization_w4a16/README.md

Lines changed: 99 additions & 0 deletions
@@ -138,6 +138,105 @@ We can see the resulting scores look good!
| | |strict-match | 5|exact_match||0.720|± |0.0285|
```

---

## iMatrix Importance-Weighted Quantization

`imatrix_mse` is an observer that uses per-channel activation importance (E[x²]) to weight quantization error during range selection. Channels that carry more signal get more careful range optimization.

Two components work together:

- **`IMatrixGatherer`**: triggers a calibration pass so the observer can collect importance data
- **`imatrix_mse` observer**: collects E[x²] per input channel via forward pre-hooks and uses importance weighting in the MSE grid search: `err = sum(importance * |Q(w) - w|^p)`

> See [RFC #2456](https://github.com/vllm-project/llm-compressor/discussions/2456) for the full design discussion.
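The weighted error above can be sketched in a few lines of plain Python (a hypothetical helper using lists instead of tensors; the real observer operates per quantization group on torch tensors):

```python
def weighted_quant_error(w, w_q, importance, p=2.4):
    """err = sum(importance * |Q(w) - w|^p).

    w:          original weight values
    w_q:        dequantized reconstruction Q(w)
    importance: per-channel E[x²] collected during calibration
    """
    return sum(imp * abs(q - x) ** p for x, q, imp in zip(w, w_q, importance))


# The same reconstruction error costs more on a high-importance channel,
# so range selection works harder to keep that channel accurate:
err_uniform = weighted_quant_error([1.0, 1.0], [0.9, 0.9], [1.0, 1.0])
err_skewed = weighted_quant_error([1.0, 1.0], [0.9, 0.9], [4.0, 1.0])
```

Because each channel's error is scaled by its importance, a candidate range is only accepted when it helps the channels that actually carry signal.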
### Usage

```bash
python3 llama3_imatrix_example.py
```

The simplest setup uses `preset_name_to_scheme` to configure W4A16 and swaps in the `imatrix_mse` observer:

```python
from compressed_tensors.quantization import preset_name_to_scheme
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.modifiers.transform.imatrix import IMatrixGatherer

scheme = preset_name_to_scheme("W4A16", ["Linear"])
scheme.weights.observer = "imatrix_mse"

recipe = [
    IMatrixGatherer(ignore=["lm_head"]),
    QuantizationModifier(
        config_groups={"group_0": scheme},
        ignore=["lm_head"],
    ),
]
```
### Composing with GPTQ

iMatrix composes with GPTQ by providing importance-weighted ranges for the Hessian-based rounding:

```python
from llmcompressor.modifiers.gptq import GPTQModifier

scheme = preset_name_to_scheme("W4A16", ["Linear"])
scheme.weights.observer = "imatrix_mse"

recipe = [
    IMatrixGatherer(ignore=["lm_head"]),
    GPTQModifier(
        config_groups={"group_0": scheme},
        ignore=["lm_head"],
    ),
]
```
### Results

W4A16, Llama-3.1-8B, WikiText-2 token-level perplexity (141 chunks × 2048):

**group_size=128:**

| Config | PPL |
|---|---|
| FP16 baseline | 6.24 |
| RTN `memoryless_minmax` | 6.96 |
| RTN `imatrix_mse` | 6.97 |
| GPTQ | 6.89 |
| GPTQ + `imatrix_mse` | 6.82 |

**group_size=32:**

| Config | PPL |
|---|---|
| RTN `memoryless_minmax` | 6.74 |
| RTN `imatrix_mse` | 6.73 |
| GPTQ | 6.70 |
| GPTQ + `imatrix_mse` | 6.66 |

GPTQ + `imatrix_mse` gives the best result at both group sizes with default observer settings. In the RTN configurations, `imatrix_mse` is within noise of `memoryless_minmax` (6.97 vs 6.96 at group_size=128, 6.73 vs 6.74 at group_size=32).
### Observer Parameters

The observer accepts optional `observer_kwargs` for fine-tuning:

| Parameter | Default | Description |
|---|---|---|
| `norm` | 2.4 | Error exponent (`\|Q(w) - w\|^norm`) |
| `maxshrink` | 0.20 | Max fraction to shrink the range |
| `grid` | 20 | Number of grid search steps |
| `patience` | 5 | Early stopping after N steps without improvement |
| `maxgrow` | 0.0 | Max fraction to grow the range beyond observed min/max |

The defaults work well for GPTQ composition. For RTN, increasing `maxshrink` (e.g. 0.80) allows the observer to optimize ranges more aggressively:

```python
scheme.weights.observer_kwargs = {"maxshrink": 0.80}
```
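How `maxshrink`, `grid`, and `patience` interact can be sketched as follows (an assumed reading of the search semantics, not the observer's actual code: each step shrinks the observed range by a fraction of `maxshrink`, and the search stops early after `patience` steps without improvement):

```python
def search_range(w_min, w_max, error_fn, maxshrink=0.20, grid=20, patience=5):
    """Grid search over progressively shrunken ranges; returns the best (min, max)."""
    best = (w_min, w_max)
    best_err = error_fn(w_min, w_max)
    stale = 0
    for i in range(1, grid + 1):
        shrink = maxshrink * i / grid  # step i shrinks up to maxshrink total
        cand = (w_min * (1 - shrink), w_max * (1 - shrink))
        err = error_fn(*cand)
        if err < best_err:
            best, best_err, stale = cand, err, 0
        else:
            stale += 1
            if stale >= patience:  # early stopping
                break
    return best
```

With `maxshrink=0.20` the range can tighten by at most 20%; raising it to 0.80 (as suggested for RTN above) lets the search explore much tighter ranges at the cost of more clipping.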
### Questions or Feature Requests?

Please open an issue on `vllm-project/llm-compressor`
Lines changed: 61 additions & 0 deletions
@@ -0,0 +1,61 @@
from compressed_tensors.offload import dispatch_model
from compressed_tensors.quantization import preset_name_to_scheme
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.modifiers.transform.imatrix import IMatrixGatherer

# Select model and load it.
model_id = "meta-llama/Meta-Llama-3.1-8B"

model = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Select calibration dataset.
DATASET_ID = "open_platypus"

# Select number of samples. 512 samples is a good place to start.
# Increasing the number of samples can improve accuracy.
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048

# Configure the quantization algorithm to run.
#   * trigger a calibration pass with IMatrixGatherer so the observer can collect E[x²]
#   * quantize the weights to 4 bit with group size 128
#   * use imatrix_mse observer to weight quantization error by channel importance
scheme = preset_name_to_scheme("W4A16", ["Linear"])
scheme.weights.observer = "imatrix_mse"

recipe = [
    IMatrixGatherer(ignore=["lm_head"]),
    QuantizationModifier(
        config_groups={"group_0": scheme},
        ignore=["lm_head"],
    ),
]

# Apply algorithms.
oneshot(
    model=model,
    dataset=DATASET_ID,
    splits={"calibration": "train[:5%]"},
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)

# Confirm generations of the quantized model look sane.
print("\n\n")
print("========== SAMPLE GENERATION ==============")
dispatch_model(model)
sample = tokenizer("Hello my name is", return_tensors="pt")
sample = {key: value.to(model.device) for key, value in sample.items()}
output = model.generate(**sample, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")

# Save to disk compressed.
SAVE_DIR = model_id.rstrip("/").split("/")[-1] + "-W4A16-G128-imatrix"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)

src/llmcompressor/modifiers/quantization/calibration.py

Lines changed: 2 additions & 0 deletions
@@ -82,6 +82,7 @@ def initialize_observer(
         observer, base_name=base_name, args=args, module=module
     )
     module.register_module(f"{base_name}_observer", observer)
+    observer.init(module)


 def call_observer(
@@ -264,6 +265,7 @@ def freeze_module_quantization(module: Module):
     for name in ("input", "weight", "output", "q", "k", "v"):
         obs_name = f"{name}_observer"
         if hasattr(module, obs_name):
+            getattr(module, obs_name).detach(module)
             delattr(module, obs_name)

     module.quantization_status = QuantizationStatus.FROZEN
Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
# ruff: noqa

from .base import *
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
from typing import Dict, List, Union

from compressed_tensors.quantization import QuantizationArgs
from compressed_tensors.utils import match_named_modules
from pydantic import Field

from llmcompressor.core import Event, EventType, State
from llmcompressor.modifiers import Modifier
from llmcompressor.observers.base import Observer

__all__ = ["IMatrixGatherer"]


class IMatrixGatherer(Modifier):
    """
    Lifecycle trigger for iMatrix importance collection.

    Triggers a calibration pass so that ``IMatrixMSEObserver`` can collect
    E[x²] via its ``init()`` hook. Does **not** quantize weights — the
    actual quantization is done by the subsequent
    ``QuantizationModifier`` / ``GPTQModifier``.

    The observer's ``detach()`` method computes ``_imatrix_importance``
    from the accumulated statistics and leaves it on the module for the
    next quantization pass to consume.

    Example recipe::

        recipe:
        - IMatrixGatherer:
            ignore: ["lm_head"]
        - QuantizationModifier:
            config_groups:
                group_0:
                    targets: ["Linear"]
                    weights:
                        observer: imatrix_mse

    Or composed with GPTQ::

        recipe:
        - IMatrixGatherer:
            ignore: ["lm_head"]
        - GPTQModifier:
            config_groups:
                group_0:
                    targets: ["Linear"]
                    weights:
                        observer: imatrix_mse

    .. note::
        Auto-prepend (inserting the gatherer automatically when
        ``imatrix_mse`` is detected in a recipe) is planned for a
        follow-up PR.

    :param targets: module types to instrument (default: ``["Linear"]``)
    :param ignore: layer name patterns to skip (default: ``["lm_head"]``)
    :param weight_observer: observer to attach during calibration.
        Must be ``"imatrix_mse"`` (default).
    """

    targets: Union[str, List[str]] = Field(default_factory=lambda: ["Linear"])
    ignore: List[str] = Field(default_factory=lambda: ["lm_head"])
    weight_observer: str = "imatrix_mse"

    # ------------------------------------------------------------------ #
    # Lifecycle
    # ------------------------------------------------------------------ #

    def on_initialize(self, state: State, **kwargs) -> bool:
        """
        Attach iMatrix observers to target modules for E[x²] collection
        """
        self._resolved_targets = (
            self.targets if isinstance(self.targets, list) else [self.targets]
        )
        self._observers: Dict[str, Observer] = {}

        # Minimal QuantizationArgs — only used to instantiate the observer,
        # no quantization config is applied to the model.
        observer_args = QuantizationArgs(observer=self.weight_observer)

        for name, module in match_named_modules(
            state.model, self._resolved_targets, self.ignore
        ):
            observer = Observer.load_from_registry(
                self.weight_observer,
                base_name="weight",
                args=observer_args,
                module=module,
            )
            module.register_module("weight_observer", observer)
            observer.init(module)
            self._observers[name] = observer

        return True

    def on_start(self, state: State, event: Event, **kwargs):
        self.started_ = True

    def on_event(self, state: State, event: Event, **kwargs):
        if event.type_ == EventType.CALIBRATION_EPOCH_START:
            if not self.started_:
                self.on_start(state, None)

        if event.type_ == EventType.CALIBRATION_EPOCH_END:
            if not self.ended_:
                self.on_end(state, None)

    def on_end(self, state: State, event: Event, **kwargs):
        self.ended_ = True
        for name, observer in self._observers.items():
            module = observer.module() if observer.module is not None else None
            if module is not None:
                observer.detach(module)
                if hasattr(module, "weight_observer"):
                    delattr(module, "weight_observer")
        self._observers.clear()

    def on_finalize(self, state: State, **kwargs) -> bool:
        """
        Clean up importance tensors so they don't end up in the checkpoint
        """
        if not self.ended_:
            self.on_end(state, None)

        # Clean up importance tensors so they don't end up in checkpoint
        for _, module in match_named_modules(
            state.model, self._resolved_targets, self.ignore
        ):
            if hasattr(module, "_imatrix_importance"):
                del module._imatrix_importance

        return True

src/llmcompressor/observers/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -14,3 +14,4 @@
 from .moving_base import *
 from .min_max import *
 from .mse import *
+from .imatrix import *

src/llmcompressor/observers/base.py

Lines changed: 18 additions & 0 deletions
@@ -133,6 +133,24 @@ def _get_module_param(self, name: str) -> Optional[torch.nn.Parameter]:
         with align_module_device(module):
             return getattr(module, f"{self.base_name}_{name}", None)

+    def init(self, module: torch.nn.Module) -> None:
+        """
+        Called when the observer is attached to a module.
+        Subclasses can override to register hooks or initialize state.
+
+        :param module: the module this observer is being attached to
+        """
+        pass
+
+    def detach(self, module: torch.nn.Module) -> None:
+        """
+        Called before the observer is deleted from a module.
+        Subclasses can override to remove hooks and clean up module attributes.
+
+        :param module: the module this observer is being removed from
+        """
+        pass
+
     def _check_has_global_scale(self, global_scale: Optional[torch.nn.Parameter]):
         if (
             self.args.strategy == QuantizationStrategy.TENSOR_GROUP
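The `init()`/`detach()` contract added above can be illustrated with a minimal sketch (pure Python with a stand-in module class and hypothetical names; the real observer registers torch forward pre-hooks and accumulates over tensors):

```python
class FakeLinear:
    """Stand-in for torch.nn.Module: just enough pre-hook machinery."""

    def __init__(self):
        self._pre_hooks = []

    def register_forward_pre_hook(self, fn):
        self._pre_hooks.append(fn)
        return fn  # stand-in for a RemovableHandle

    def forward(self, x):
        for fn in self._pre_hooks:
            fn(self, (x,))
        return x


class EnergyObserver:
    """Accumulates per-channel E[x²] between init() and detach()."""

    def init(self, module):
        # register a pre-hook that accumulates sum(x²) and a sample count
        self._sum_sq, self._count = None, 0
        self._hook = module.register_forward_pre_hook(self._collect)

    def _collect(self, module, inputs):
        x = inputs[0]
        sq = [v * v for v in x]
        self._sum_sq = sq if self._sum_sq is None else [
            a + b for a, b in zip(self._sum_sq, sq)
        ]
        self._count += 1

    def detach(self, module):
        # leave the importance estimate on the module, then unhook
        module._imatrix_importance = [s / self._count for s in self._sum_sq]
        module._pre_hooks.remove(self._hook)


m = FakeLinear()
obs = EnergyObserver()
obs.init(m)
m.forward([1.0, 2.0])
m.forward([3.0, 0.0])
obs.detach(m)
# E[x²] per channel: [(1 + 9) / 2, (4 + 0) / 2] = [5.0, 2.0]
```

This mirrors the lifecycle in `calibration.py`: `initialize_observer` calls `init()` right after registration, and `freeze_module_quantization` calls `detach()` just before the observer is deleted.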
