Skip to content

Commit adef032

Browse files
committed
feat: add iMatrix weighted MSE observer and IMatrixGatherer
- imatrix_mse observer with E[x²] importance weighting
- IMatrixGatherer transform using match_named_modules + CPU offload
- Unit tests for observer and gatherer
- E2E integration tests

RFC #2456

Signed-off-by: Gilles Turpin <turpingilles15@gmail.com>
1 parent 822668a commit adef032

File tree

9 files changed

+1528
-0
lines changed

9 files changed

+1528
-0
lines changed
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
from compressed_tensors.offload import dispatch_model
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.modifiers.transform.imatrix import IMatrixGatherer

# Model selection and loading.
model_id = "meta-llama/Meta-Llama-3.1-8B"

model = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Calibration dataset.
DATASET_ID = "open_platypus"

# Calibration budget: 512 samples is a reasonable starting point;
# increasing the sample count can improve accuracy.
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048

# Quantization recipe:
#   1. IMatrixGatherer collects per-channel importance statistics (E[x²]).
#   2. QuantizationModifier quantizes the weights to 4 bit with group
#      size 128, using the imatrix_mse observer so that quantization
#      error is weighted by channel importance.
recipe = [
    IMatrixGatherer(ignore=["lm_head"]),
    QuantizationModifier(
        config_groups={
            "group_0": {
                "targets": ["Linear"],
                "weights": {
                    "num_bits": 4,
                    "type": "int",
                    "symmetric": True,
                    "strategy": "group",
                    "group_size": 128,
                    "observer": "imatrix_mse",
                    "observer_kwargs": {
                        "norm": 2.4,
                        "maxshrink": 0.20,
                        "grid": 20,
                    },
                },
            }
        },
        ignore=["lm_head"],
    ),
]

# Run calibration + quantization in a single one-shot pass.
oneshot(
    model=model,
    dataset=DATASET_ID,
    splits={"calibration": "train[:5%]"},
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)

# Sanity-check a generation from the quantized model.
print("\n\n")
print("========== SAMPLE GENERATION ==============")
dispatch_model(model)
sample = tokenizer("Hello my name is", return_tensors="pt")
sample = {k: v.to(model.device) for k, v in sample.items()}
output = model.generate(**sample, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")

# Persist the compressed checkpoint next to the working directory.
model_name = model_id.rstrip("/").split("/")[-1]
SAVE_DIR = model_name + "-W4A16-G128-imatrix"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# ruff: noqa
2+
3+
from .base import *
Lines changed: 247 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,247 @@
1+
"""
2+
IMatrixGatherer: collects importance-weighted activation statistics
3+
(E[x²] per input channel) on Linear modules via forward pre-hooks.
4+
5+
Stores ``module._imatrix_importance`` as a 1-D float32 tensor of
6+
shape ``(in_features,)`` on each target module. Does **not** quantize
7+
or modify weights in any way.
8+
9+
The downstream ``imatrix_mse`` observer reads this attribute during
10+
its grid search to weight quantization error by channel importance.
11+
12+
Example recipe::
13+
14+
recipe:
15+
- IMatrixGatherer:
16+
ignore: ["lm_head"]
17+
- QuantizationModifier:
18+
config_groups:
19+
group_0:
20+
targets: ["Linear"]
21+
weights:
22+
observer: imatrix_mse
23+
24+
Or composed with AWQ and GPTQ::
25+
26+
recipe:
27+
- AWQModifier(...)
28+
- IMatrixGatherer:
29+
ignore: ["lm_head"]
30+
- GPTQModifier:
31+
config_groups:
32+
group_0:
33+
targets: ["Linear"]
34+
weights:
35+
observer: imatrix_mse
36+
37+
.. note::
38+
Auto-prepend (inserting the gatherer automatically when
39+
``imatrix_mse`` is detected in a recipe) is planned for a
40+
follow-up PR.
41+
42+
.. note::
43+
Unlike AWQModifier, this gatherer does not use IntermediatesCache
44+
because it only stores a single accumulated 1-D tensor per layer
45+
(not full batch activations). A simple CPU offload at
46+
CALIBRATION_EPOCH_END is sufficient.
47+
"""
48+
49+
from typing import Dict, List, Optional, Union
50+
51+
import torch
52+
from compressed_tensors.utils import match_named_modules
53+
from loguru import logger
54+
from pydantic import Field
55+
from torch.nn import Module
56+
57+
from llmcompressor.core import Event, EventType, State
58+
from llmcompressor.modifiers import Modifier
59+
60+
__all__ = ["IMatrixGatherer"]
61+
62+
63+
class IMatrixGatherer(Modifier):
    """
    Collects importance-weighted activation statistics (E[x²])
    for each targeted module via forward pre-hooks.

    Stores ``module._imatrix_importance`` as a 1-D float32 tensor
    of shape ``(in_features,)`` on each target module.

    Does NOT quantize. Does NOT modify weights.

    Statistics are kept on GPU during calibration for speed, then
    offloaded to CPU at CALIBRATION_EPOCH_END to free GPU memory
    before quantization begins.

    :param ignore: layer name patterns to skip (default: ``["lm_head"]``)
    :param targets: module types to instrument (default: ``["Linear"]``)
    """

    ignore: Union[str, List[str]] = Field(
        default_factory=lambda: ["lm_head"],
    )
    targets: Union[str, List[str]] = Field(
        default_factory=lambda: ["Linear"],
    )

    # -- internal state (excluded from serialisation) --
    # _target_names: module names selected by _resolve_targets
    # _sums: per-module running sum of x² across all calibration tokens
    # _counts: per-module number of tokens folded into _sums
    _target_names: Optional[List[str]] = None
    _sums: Optional[Dict[str, torch.Tensor]] = None
    _counts: Optional[Dict[str, int]] = None

    # ------------------------------------------------------------------ #
    # Lifecycle
    # ------------------------------------------------------------------ #

    def on_initialize(self, state: State, **kwargs) -> bool:
        """
        Validate one-shot usage and pre-allocate accumulators.

        :param state: session state carrying the model to instrument
        :raises ValueError: if start/end indicate a training-time schedule
        :return: True on successful initialization
        """
        # This modifier only makes sense in one-shot calibration: a
        # start/end other than None/-1 implies a training schedule.
        if self.end and self.end != -1:
            raise ValueError(
                f"{self.__class__.__name__} can only be applied "
                f"during one-shot. Expected end to be None or "
                f"-1, got {self.end}"
            )
        if self.start and self.start != -1:
            raise ValueError(
                f"{self.__class__.__name__} can only be applied "
                f"during one-shot. Expected start to be None "
                f"or -1, got {self.start}"
            )

        self._resolve_targets(state.model)
        return True

    def on_start(self, state: State, event: Event, **kwargs):
        # Begin collecting: install forward-pre hooks on all targets.
        self.started_ = True
        self._register_accumulation_hooks(state.model)

    def on_event(self, state: State, event: Event, **kwargs):
        if event.type_ == EventType.CALIBRATION_EPOCH_START:
            if not self.started_:
                self.on_start(state, None)

        # Intermediate attach keeps partially-accumulated statistics
        # available between sequential pipeline stages.
        if event.type_ == EventType.SEQUENTIAL_EPOCH_END:
            self._compute_and_attach(state.model)

        if event.type_ == EventType.CALIBRATION_EPOCH_END:
            # Final attach; move results to CPU to free GPU memory
            # before quantization starts.
            self._compute_and_attach(state.model, offload_to_cpu=True)

            if not self.ended_:
                self.on_end(state, None)

    def on_end(self, state: State, event: Event, **kwargs):
        # Stop collecting: detach all registered hooks exactly once.
        self.ended_ = True
        self.remove_hooks()

    def on_finalize(self, state: State, **kwargs) -> bool:
        if not self.ended_:
            self.on_end(state, None)

        # Drop accumulators; the per-module `_imatrix_importance`
        # attributes attached earlier remain for downstream observers.
        self._sums = None
        self._counts = None
        self._target_names = None
        return True

    # ------------------------------------------------------------------ #
    # Target resolution
    # ------------------------------------------------------------------ #

    def _resolve_targets(self, model: Module):
        """Identify target modules using compressed_tensors matching."""
        self._target_names = []
        self._sums = {}
        self._counts = {}

        for name, module in match_named_modules(model, self.targets, self.ignore):
            # Only modules exposing `in_features` (Linear-like) can
            # receive a per-input-channel importance vector.
            if not hasattr(module, "in_features"):
                continue

            self._target_names.append(name)
            self._sums[name] = torch.zeros(module.in_features, dtype=torch.float32)
            self._counts[name] = 0

        logger.info(f"IMatrixGatherer: targeting {len(self._target_names)}" f" modules")

    # ------------------------------------------------------------------ #
    # Hook registration
    # ------------------------------------------------------------------ #

    def _register_accumulation_hooks(self, model: Module):
        """Attach a forward-pre hook to every target module."""

        def _create_hook_fn(layer_name: str):
            """Closure captures layer_name."""

            def _hook(module: Module, args):
                # Forward-pre hooks receive positional args as a tuple;
                # unwrap down to the first tensor input.
                x = args[0] if not isinstance(args, torch.Tensor) else args
                if isinstance(x, tuple):
                    x = x[0]
                if not isinstance(x, torch.Tensor):
                    return

                # Per-token accumulation: sum x² over every leading
                # (batch/sequence) dim, leaving one value per input channel.
                x_f = x.detach().float()
                n_tokens = x_f[..., 0].numel()
                token_sum = x_f.pow(2).sum(dim=list(range(x_f.dim() - 1)))

                # Lazily migrate the accumulator to the activation's
                # device so accumulation stays on-GPU during calibration.
                device = self._sums[layer_name].device
                if device != token_sum.device:
                    self._sums[layer_name] = self._sums[layer_name].to(token_sum.device)

                self._sums[layer_name].add_(token_sum)
                self._counts[layer_name] += n_tokens

            return _hook

        for name, module in match_named_modules(model, self.targets, self.ignore):
            # Skip modules filtered out by _resolve_targets
            # (e.g. those without `in_features`).
            if name in self._sums:
                self.register_hook(
                    module,
                    _create_hook_fn(name),
                    "forward_pre",
                )

    # ------------------------------------------------------------------ #
    # Compute & attach
    # ------------------------------------------------------------------ #

    def _compute_and_attach(self, model: Module, offload_to_cpu: bool = False):
        """
        Compute E[x²] and store on each module.

        :param model: model whose modules receive importance data
        :param offload_to_cpu: if True, move importance tensors to CPU
            after attaching. Set at CALIBRATION_EPOCH_END to free
            GPU memory before quantization.
        """
        attached = 0
        for name, module in match_named_modules(model, self.targets, self.ignore):
            if name not in self._sums:
                continue

            # No tokens observed for this module (e.g. never executed);
            # leave it without an importance attribute.
            count = self._counts[name]
            if count == 0:
                continue

            # Mean of x² over all observed tokens = E[x²] per channel.
            importance = self._sums[name] / count

            if offload_to_cpu:
                importance = importance.to("cpu")
                # also free the accumulator
                del self._sums[name]

            module._imatrix_importance = importance

            attached += 1
            logger.debug(
                f"iMatrix {name}: "
                f"mean={importance.mean():.4f}, "
                f"max={importance.max():.4f}, "
                f"ratio="
                f"{importance.max() / (importance.mean() + 1e-10):.1f}"
            )

        logger.info(
            f"IMatrixGatherer: attached importance to "
            f"{attached} modules" + (" (offloaded to CPU)" if offload_to_cpu else "")
        )

src/llmcompressor/observers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,4 @@
1414
from .moving_base import *
1515
from .min_max import *
1616
from .mse import *
17+
from .imatrix import *

0 commit comments

Comments
 (0)