Commit 20f2a0a

feat: add iMatrix weighted MSE observer and IMatrixGatherer

- `imatrix_mse` observer with E[x²] importance weighting
- `IMatrixGatherer` transform using `match_named_modules` + CPU offload
- Unit tests for observer and gatherer
- E2E integration tests

RFC #2456

Signed-off-by: Gilles Turpin <turpingilles15@gmail.com>

1 parent 0bc916e

10 files changed: +1519 −0 lines changed


examples/quantization_w4a16/README.md

Lines changed: 99 additions & 0 deletions
@@ -138,6 +138,105 @@ We can see the resulting scores look good!
| | |strict-match | 5|exact_match||0.720|± |0.0285|
```
---

## iMatrix Importance-Weighted Quantization

`imatrix_mse` is an observer that uses per-channel activation importance (E[x²]) to weight quantization error during range selection. Channels that carry more signal receive more careful range optimization.

Two components work together:

- **`IMatrixGatherer`**: collects E[x²] per input channel via forward pre-hooks during calibration
- **`imatrix_mse` observer**: extends the MSE grid search with importance weighting: `err = sum(importance * |Q(w) - w|^p)`

> See [RFC #2456](https://github.com/vllm-project/llm-compressor/discussions/2456) for the full design discussion.
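The gathering step can be sketched as a forward pre-hook that keeps a running mean of x² per input channel. This is a minimal illustration, not the actual `IMatrixGatherer` implementation; the function name and attribute names here are hypothetical.

```python
import torch

def attach_imatrix_hook(module: torch.nn.Linear):
    """Illustrative sketch: accumulate a running mean of x^2 per input channel."""
    module.register_buffer("imatrix", torch.zeros(module.in_features))
    module._imatrix_count = 0

    def pre_hook(mod, args):
        # Flatten batch/sequence dims so each row is one token's activations.
        x = args[0].reshape(-1, mod.in_features)
        n = x.shape[0]
        total = mod._imatrix_count + n
        # Running-mean update: E[x^2] over all calibration tokens seen so far.
        mod.imatrix.mul_(mod._imatrix_count / total).add_(x.pow(2).sum(dim=0) / total)
        mod._imatrix_count = total

    return module.register_forward_pre_hook(pre_hook)

linear = torch.nn.Linear(8, 4)
handle = attach_imatrix_hook(linear)
for _ in range(3):
    linear(torch.randn(2, 5, 8))  # calibration forward passes
handle.remove()
# linear.imatrix now holds per-input-channel importance, shape (8,)
```

The hook sees inputs before the module runs, so the statistic reflects exactly what the layer consumes during calibration.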
### Usage

```bash
python3 llama3_imatrix_example.py
```

The simplest setup uses `preset_name_to_scheme` to configure W4A16 and swaps in the `imatrix_mse` observer:

```python
from compressed_tensors.quantization import preset_name_to_scheme
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.modifiers.transform.imatrix import IMatrixGatherer

scheme = preset_name_to_scheme("W4A16", ["Linear"])
scheme.weights.observer = "imatrix_mse"

recipe = [
    IMatrixGatherer(ignore=["lm_head"]),
    QuantizationModifier(
        config_groups={"group_0": scheme},
        ignore=["lm_head"],
    ),
]
```
### Composing with GPTQ

iMatrix composes with GPTQ by providing importance-weighted ranges for the Hessian-based rounding:

```python
from llmcompressor.modifiers.gptq import GPTQModifier

scheme = preset_name_to_scheme("W4A16", ["Linear"])
scheme.weights.observer = "imatrix_mse"

recipe = [
    IMatrixGatherer(ignore=["lm_head"]),
    GPTQModifier(
        config_groups={"group_0": scheme},
        ignore=["lm_head"],
    ),
]
```
### Results

W4A16, Llama-3.1-8B, WikiText-2 token-level perplexity (141 chunks × 2048 tokens):

**group_size=128:**

| Config | PPL |
|---|---|
| FP16 baseline | 6.24 |
| RTN `memoryless_minmax` | 6.96 |
| RTN `imatrix_mse` | 6.97 |
| GPTQ | 6.89 |
| GPTQ + `imatrix_mse` | 6.82 |

**group_size=32:**

| Config | PPL |
|---|---|
| RTN `memoryless_minmax` | 6.74 |
| RTN `imatrix_mse` | 6.73 |
| GPTQ | 6.70 |
| GPTQ + `imatrix_mse` | 6.66 |

GPTQ + `imatrix_mse` gives the best result at both group sizes with default observer settings, and no configuration shows a meaningful quality regression from iMatrix (the largest gap versus the plain observer is 0.01 PPL).
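For reference, the metric above is token-level perplexity: exp of the mean negative log-likelihood over non-overlapping 2048-token chunks. A generic sketch of the metric (not the exact evaluation harness used for the table):

```python
import numpy as np

def chunked_perplexity(token_nlls, chunk_len=2048):
    """Token-level PPL over non-overlapping fixed-length chunks:
    drop the ragged tail, then take exp(mean NLL) over all kept tokens."""
    n_chunks = len(token_nlls) // chunk_len
    kept = np.asarray(token_nlls[: n_chunks * chunk_len])
    return n_chunks, float(np.exp(kept.mean()))

# With a uniform per-token NLL of log(6.24), the metric recovers PPL 6.24.
nlls = np.full(141 * 2048 + 100, np.log(6.24))
n_chunks, ppl = chunked_perplexity(nlls)
```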
### Observer Parameters

The observer accepts optional `observer_kwargs` for fine-tuning:

| Parameter | Default | Description |
|---|---|---|
| `norm` | 2.4 | Error exponent (`\|Q(w) - w\|^norm`) |
| `maxshrink` | 0.20 | Max fraction to shrink the range |
| `grid` | 20 | Number of grid-search steps |
| `patience` | 5 | Early stopping after N steps without improvement |
| `maxgrow` | 0.0 | Max fraction to grow the range beyond the observed min/max |

The defaults work well for GPTQ composition. For RTN, increasing `maxshrink` (e.g. 0.80) lets the observer optimize ranges more aggressively:

```python
scheme.weights.observer_kwargs = {"maxshrink": 0.80}
```
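The interplay of `norm`, `maxshrink`, `grid`, and `patience` can be illustrated with a minimal per-tensor sketch of the importance-weighted grid search. This is an assumption-laden illustration, not the actual observer: the real implementation works per group/channel and also supports `maxgrow`, which is omitted here for brevity.

```python
import numpy as np

def imatrix_mse_range(w, importance, bits=4, norm=2.4, maxshrink=0.2,
                      grid=20, patience=5):
    """Illustrative sketch: shrink [min, max] in `grid` steps and keep the
    range minimizing sum(importance * |Q(w) - w|**norm)."""
    qmax = 2 ** bits - 1
    lo, hi = w.min(), w.max()
    best_err, best, stale = np.inf, (lo, hi), 0
    for i in range(grid + 1):
        shrink = 1.0 - maxshrink * i / grid  # candidate shrink factor
        a, b = lo * shrink, hi * shrink
        scale = max(b - a, 1e-9) / qmax
        # Fake-quantize to the candidate asymmetric range.
        q = np.clip(np.round((w - a) / scale), 0, qmax) * scale + a
        err = np.sum(importance * np.abs(q - w) ** norm)
        if err < best_err:
            best_err, best, stale = err, (a, b), 0
        else:
            stale += 1
            if stale >= patience:  # early stopping after `patience` stale steps
                break
    return best

w = np.random.randn(64)
importance = np.abs(np.random.randn(64)) + 0.1  # stand-in for gathered E[x^2]
a, b = imatrix_mse_range(w, importance)
```

Because the unshrunk range is always the first candidate, the search can only match or improve on plain min/max selection under the weighted error.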
### Questions or Feature Requests?

Please open an issue on `vllm-project/llm-compressor`.
Lines changed: 61 additions & 0 deletions
@@ -0,0 +1,61 @@
from compressed_tensors.offload import dispatch_model
from compressed_tensors.quantization import preset_name_to_scheme
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.modifiers.transform.imatrix import IMatrixGatherer

# Select model and load it.
model_id = "meta-llama/Meta-Llama-3.1-8B"

model = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Select calibration dataset.
DATASET_ID = "open_platypus"

# Select number of samples. 512 samples is a good place to start.
# Increasing the number of samples can improve accuracy.
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048

# Configure the quantization algorithm to run.
#   * collect per-channel importance statistics (E[x²]) with IMatrixGatherer
#   * quantize the weights to 4 bit with group size 128
#   * use the imatrix_mse observer to weight quantization error by channel importance
scheme = preset_name_to_scheme("W4A16", ["Linear"])
scheme.weights.observer = "imatrix_mse"

recipe = [
    IMatrixGatherer(ignore=["lm_head"]),
    QuantizationModifier(
        config_groups={"group_0": scheme},
        ignore=["lm_head"],
    ),
]

# Apply algorithms.
oneshot(
    model=model,
    dataset=DATASET_ID,
    splits={"calibration": "train[:5%]"},
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)

# Confirm generations of the quantized model look sane.
print("\n\n")
print("========== SAMPLE GENERATION ==============")
dispatch_model(model)
sample = tokenizer("Hello my name is", return_tensors="pt")
sample = {key: value.to(model.device) for key, value in sample.items()}
output = model.generate(**sample, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")

# Save to disk compressed.
SAVE_DIR = model_id.rstrip("/").split("/")[-1] + "-W4A16-G128-imatrix"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
# ruff: noqa

from .base import *
