Skip to content

Commit 2ed8b25

Browse files
Merge branch 'main' into model-free-ptq-runtime-optimization
2 parents 498d38e + 464d000 commit 2ed8b25

File tree

15 files changed

+1689
-4
lines changed

15 files changed

+1689
-4
lines changed

examples/imatrix/README.md

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
# iMatrix Importance-Weighted Quantization
2+
3+
`imatrix_mse` is an observer that uses per-channel activation importance (E[x²]) to weight quantization error during range selection. Channels that carry more signal get more careful range optimization.
4+
5+
Two components work together:
6+
- **`IMatrixGatherer`**: triggers a calibration pass so the observer can collect importance data
7+
- **`imatrix_mse` observer**: collects E[x²] per input channel via forward pre-hooks and uses importance weighting in the MSE grid search: `err = sum(importance * |Q(w) - w|^p)`
8+
9+
> See [RFC #2456](https://github.com/vllm-project/llm-compressor/discussions/2456) for the full design discussion.
10+
11+
## Quickstart
12+
13+
```bash
14+
python3 llama3_imatrix_example.py
15+
```
16+
17+
The simplest setup uses `preset_name_to_scheme` to configure W4A16 and swaps in the `imatrix_mse` observer:
18+
19+
```python
20+
from compressed_tensors.quantization import preset_name_to_scheme
21+
from llmcompressor.modifiers.quantization import QuantizationModifier
22+
from llmcompressor.modifiers.transform.imatrix import IMatrixGatherer
23+
24+
scheme = preset_name_to_scheme("W4A16", ["Linear"])
25+
scheme.weights.observer = "imatrix_mse"
26+
27+
recipe = [
28+
IMatrixGatherer(ignore=["lm_head"]),
29+
QuantizationModifier(
30+
config_groups={"group_0": scheme},
31+
ignore=["lm_head"],
32+
),
33+
]
34+
```
35+
36+
## Composing with GPTQ
37+
38+
iMatrix composes with GPTQ by providing importance-weighted ranges for the Hessian-based rounding:
39+
40+
```python
41+
from llmcompressor.modifiers.gptq import GPTQModifier
42+
43+
scheme = preset_name_to_scheme("W4A16", ["Linear"])
44+
scheme.weights.observer = "imatrix_mse"
45+
46+
recipe = [
47+
IMatrixGatherer(ignore=["lm_head"]),
48+
GPTQModifier(
49+
config_groups={"group_0": scheme},
50+
ignore=["lm_head"],
51+
),
52+
]
53+
```
54+
55+
## Results
56+
57+
W4A16, Llama-3.1-8B, group_size=128, WikiText-2 token-level perplexity (141 chunks x 2048):
58+
59+
| Config | PPL |
60+
|---|---|
61+
| FP16 baseline | 6.24 |
62+
| RTN `memoryless_minmax` | 6.96 |
63+
| GPTQ | 6.92 |
64+
| AWQ | 6.89 |
65+
| RTN `imatrix_mse` | 6.85 |
66+
| GPTQ + `imatrix_mse` | 6.83 |
67+
68+
## Questions or Feature Requests?
69+
70+
Please open an issue on `vllm-project/llm-compressor`.
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
from compressed_tensors.offload import dispatch_model
2+
from compressed_tensors.quantization import preset_name_to_scheme
3+
from transformers import AutoModelForCausalLM, AutoTokenizer
4+
5+
from llmcompressor import oneshot
6+
from llmcompressor.modifiers.quantization import QuantizationModifier
7+
from llmcompressor.modifiers.transform.imatrix import IMatrixGatherer
8+
9+
# Select model and load it.
10+
model_id = "meta-llama/Meta-Llama-3.1-8B"
11+
12+
model = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto")
13+
tokenizer = AutoTokenizer.from_pretrained(model_id)
14+
15+
# Select calibration dataset.
16+
DATASET_ID = "open_platypus"
17+
18+
# Select number of samples. 512 samples is a good place to start.
19+
# Increasing the number of samples can improve accuracy.
20+
NUM_CALIBRATION_SAMPLES = 512
21+
MAX_SEQUENCE_LENGTH = 2048
22+
23+
# Configure the quantization algorithm to run.
24+
# * trigger a calibration pass with IMatrixGatherer so the observer can collect E[x²]
25+
# * quantize the weights to 4 bit with group size 128
26+
# * use imatrix_mse observer to weight quantization error by channel importance
27+
scheme = preset_name_to_scheme("W4A16", ["Linear"])
28+
scheme.weights.observer = "imatrix_mse"
29+
30+
recipe = [
31+
IMatrixGatherer(ignore=["lm_head"]),
32+
QuantizationModifier(
33+
config_groups={"group_0": scheme},
34+
ignore=["lm_head"],
35+
),
36+
]
37+
38+
# Apply algorithms.
39+
oneshot(
40+
model=model,
41+
dataset=DATASET_ID,
42+
splits={"calibration": "train[:5%]"},
43+
recipe=recipe,
44+
max_seq_length=MAX_SEQUENCE_LENGTH,
45+
num_calibration_samples=NUM_CALIBRATION_SAMPLES,
46+
)
47+
48+
# Confirm generations of the quantized model look sane.
49+
print("\n\n")
50+
print("========== SAMPLE GENERATION ==============")
51+
dispatch_model(model)
52+
sample = tokenizer("Hello my name is", return_tensors="pt")
53+
sample = {key: value.to(model.device) for key, value in sample.items()}
54+
output = model.generate(**sample, max_new_tokens=100)
55+
print(tokenizer.decode(output[0]))
56+
print("==========================================\n\n")
57+
58+
# Save to disk compressed.
59+
SAVE_DIR = model_id.rstrip("/").split("/")[-1] + "-W4A16-G128-imatrix"
60+
model.save_pretrained(SAVE_DIR, save_compressed=True)
61+
tokenizer.save_pretrained(SAVE_DIR)

examples/quantization_w4a16/README.md

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,3 @@ We can see the resulting scores look good!
138138
| | |strict-match | 5|exact_match||0.720|± |0.0285|
139139
```
140140

141-
### Questions or Feature Request?
142-
143-
Please open up an issue on `vllm-project/llm-compressor`

src/llmcompressor/modifiers/quantization/calibration.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ def initialize_observer(
8282
observer, base_name=base_name, args=args, module=module
8383
)
8484
module.register_module(f"{base_name}_observer", observer)
85+
observer.attach(module)
8586

8687

8788
def call_observer(
@@ -264,6 +265,7 @@ def freeze_module_quantization(module: Module):
264265
for name in ("input", "weight", "output", "q", "k", "v"):
265266
obs_name = f"{name}_observer"
266267
if hasattr(module, obs_name):
268+
getattr(module, obs_name).detach(module)
267269
delattr(module, obs_name)
268270

269271
module.quantization_status = QuantizationStatus.FROZEN

src/llmcompressor/modifiers/quantization/quantization/mixin.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
is_preset_scheme,
2020
preset_name_to_scheme,
2121
)
22+
from compressed_tensors.quantization.utils import KV_CACHE_TARGETS
2223
from compressed_tensors.utils import match_named_modules, update_offload_parameter
2324
from pydantic import Field, PrivateAttr, field_validator
2425
from torch.utils.hooks import RemovableHandle
@@ -208,7 +209,7 @@ def resolved_targets(self) -> Set[str]:
208209

209210
if self.resolved_config.kv_cache_scheme is not None:
210211
# TODO: decouple reliance on this regex for matching attention
211-
targets.add("re:.*self_attn$")
212+
targets.update(KV_CACHE_TARGETS)
212213

213214
return targets
214215

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# ruff: noqa
2+
3+
from .base import *
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
from typing import Dict, List, Union
2+
3+
from compressed_tensors.quantization import QuantizationArgs
4+
from compressed_tensors.utils import match_named_modules
5+
from pydantic import Field
6+
7+
from llmcompressor.core import Event, EventType, State
8+
from llmcompressor.modifiers import Modifier
9+
from llmcompressor.observers.base import Observer
10+
11+
__all__ = ["IMatrixGatherer"]
12+
13+
14+
class IMatrixGatherer(Modifier):
15+
"""
16+
Lifecycle trigger for iMatrix importance collection.
17+
18+
Triggers a calibration pass so that ``IMatrixMSEObserver`` can collect
19+
E[x²] via its ``attach()`` hook. Does **not** quantize weights — the
20+
actual quantization is done by the subsequent
21+
``QuantizationModifier`` / ``GPTQModifier``.
22+
23+
The observer's ``detach()`` method computes ``_imatrix_importance``
24+
from the accumulated statistics and leaves it on the module for the
25+
next quantization pass to consume.
26+
27+
Example recipe::
28+
29+
recipe:
30+
- IMatrixGatherer:
31+
ignore: ["lm_head"]
32+
- QuantizationModifier:
33+
config_groups:
34+
group_0:
35+
targets: ["Linear"]
36+
weights:
37+
observer: imatrix_mse
38+
39+
Or composed with GPTQ::
40+
41+
recipe:
42+
- IMatrixGatherer:
43+
ignore: ["lm_head"]
44+
- GPTQModifier:
45+
config_groups:
46+
group_0:
47+
targets: ["Linear"]
48+
weights:
49+
observer: imatrix_mse
50+
51+
.. note::
52+
Auto-prepend (inserting the gatherer automatically when
53+
``imatrix_mse`` is detected in a recipe) is planned for a
54+
follow-up PR.
55+
56+
:param targets: module types to instrument (default: ``["Linear"]``)
57+
:param ignore: layer name patterns to skip (default: ``["lm_head"]``)
58+
:param weight_observer: observer to attach during calibration.
59+
Must be ``"imatrix_mse"`` (default).
60+
"""
61+
62+
targets: Union[str, List[str]] = Field(default_factory=lambda: ["Linear"])
63+
ignore: List[str] = Field(default_factory=lambda: ["lm_head"])
64+
weight_observer: str = "imatrix_mse"
65+
66+
# ------------------------------------------------------------------ #
67+
# Lifecycle
68+
# ------------------------------------------------------------------ #
69+
70+
def on_initialize(self, state: State, **kwargs) -> bool:
71+
"""
72+
Attach iMatrix observers to target modules for E[x²] collection
73+
"""
74+
self._resolved_targets = (
75+
self.targets if isinstance(self.targets, list) else [self.targets]
76+
)
77+
self._observers: Dict[str, Observer] = {}
78+
79+
# Minimal QuantizationArgs — only used to instantiate the observer,
80+
# no quantization config is applied to the model.
81+
observer_args = QuantizationArgs(observer=self.weight_observer)
82+
83+
for name, module in match_named_modules(
84+
state.model, self._resolved_targets, self.ignore
85+
):
86+
observer = Observer.load_from_registry(
87+
self.weight_observer,
88+
base_name="weight",
89+
args=observer_args,
90+
module=module,
91+
)
92+
module.register_module("weight_observer", observer)
93+
observer.attach(module)
94+
self._observers[name] = observer
95+
96+
return True
97+
98+
def on_start(self, state: State, event: Event, **kwargs):
99+
self.started_ = True
100+
101+
def on_event(self, state: State, event: Event, **kwargs):
102+
if event.type_ == EventType.CALIBRATION_EPOCH_START:
103+
if not self.started_:
104+
self.on_start(state, None)
105+
106+
if event.type_ == EventType.CALIBRATION_EPOCH_END:
107+
if not self.ended_:
108+
self.on_end(state, None)
109+
110+
def on_end(self, state: State, event: Event, **kwargs):
111+
self.ended_ = True
112+
for name, observer in self._observers.items():
113+
module = observer.module() if observer.module is not None else None
114+
if module is not None:
115+
observer.detach(module)
116+
if hasattr(module, "weight_observer"):
117+
delattr(module, "weight_observer")
118+
self._observers.clear()
119+
120+
def on_finalize(self, state: State, **kwargs) -> bool:
121+
"""
122+
Clean up importance tensors so they don't end up in the checkpoint
123+
"""
124+
if not self.ended_:
125+
self.on_end(state, None)
126+
127+
# Clean up importance tensors so they don't end up in checkpoint
128+
for _, module in match_named_modules(
129+
state.model, self._resolved_targets, self.ignore
130+
):
131+
if hasattr(module, "_imatrix_importance"):
132+
del module._imatrix_importance
133+
134+
return True

src/llmcompressor/observers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,4 @@
1414
from .moving_base import *
1515
from .min_max import *
1616
from .mse import *
17+
from .imatrix import *

src/llmcompressor/observers/base.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,24 @@ def recompute_qparams(self) -> Optional[ScaleZpTuple]:
194194
global_scale=global_scale,
195195
)
196196

197+
def attach(self, module: torch.nn.Module) -> None:
198+
"""
199+
Called when the observer is attached to a module.
200+
Subclasses can override to register hooks or initialize state.
201+
202+
:param module: the module this observer is being attached to
203+
"""
204+
pass
205+
206+
def detach(self, module: torch.nn.Module) -> None:
207+
"""
208+
Called before the observer is deleted from a module.
209+
Subclasses can override to remove hooks and clean up module attributes.
210+
211+
:param module: the module this observer is being removed from
212+
"""
213+
pass
214+
197215
def _check_has_global_scale(self, global_scale: Optional[torch.nn.Parameter]):
198216
if (
199217
self.args.strategy == QuantizationStrategy.TENSOR_GROUP

0 commit comments

Comments
 (0)