Skip to content

Commit b982e23

Browse files
committed
feat: add iMatrix weighted MSE observer and IMatrixGatherer
- imatrix_mse observer with E[x²] importance weighting
- IMatrixGatherer transform using match_named_modules + CPU offload
- Unit tests for observer and gatherer
- E2E integration tests

RFC #2456

Signed-off-by: Gilles Turpin <turpingilles15@gmail.com>
1 parent 370c04c commit b982e23

File tree

8 files changed

+1391
-0
lines changed

8 files changed

+1391
-0
lines changed
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# ruff: noqa
2+
3+
from .base import *
Lines changed: 249 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,249 @@
1+
"""
2+
IMatrixGatherer: collects importance-weighted activation statistics
3+
(E[x²] per input channel) on Linear modules via forward pre-hooks.
4+
5+
Stores ``module._imatrix_importance`` as a 1-D float32 tensor of
6+
shape ``(in_features,)`` on each target module. Does **not** quantize
7+
or modify weights in any way.
8+
9+
The downstream ``imatrix_mse`` observer reads this attribute during
10+
its grid search to weight quantization error by channel importance.
11+
12+
Example recipe::
13+
14+
recipe:
15+
- IMatrixGatherer:
16+
ignore: ["lm_head"]
17+
- QuantizationModifier:
18+
config_groups:
19+
group_0:
20+
targets: ["Linear"]
21+
weights:
22+
observer: imatrix_mse
23+
24+
Or composed with AWQ and GPTQ::
25+
26+
recipe:
27+
- AWQModifier(...)
28+
- IMatrixGatherer:
29+
ignore: ["lm_head"]
30+
- GPTQModifier:
31+
config_groups:
32+
group_0:
33+
targets: ["Linear"]
34+
weights:
35+
observer: imatrix_mse
36+
37+
.. note::
38+
Auto-prepend (inserting the gatherer automatically when
39+
``imatrix_mse`` is detected in a recipe) is planned for a
40+
follow-up PR.
41+
42+
.. note::
43+
Unlike AWQModifier, this gatherer does not use IntermediatesCache
44+
because it only stores a single accumulated 1-D tensor per layer
45+
(not full batch activations). A simple CPU offload at
46+
CALIBRATION_EPOCH_END is sufficient.
47+
"""
48+
49+
from typing import Dict, List, Optional, Union
50+
51+
import torch
52+
from compressed_tensors.utils import match_named_modules
53+
from loguru import logger
54+
from pydantic import Field
55+
from torch.nn import Module
56+
57+
from llmcompressor.core import Event, EventType, State
58+
from llmcompressor.modifiers import Modifier
59+
60+
__all__ = ["IMatrixGatherer"]
61+
62+
63+
class IMatrixGatherer(Modifier):
64+
"""
65+
Collects importance-weighted activation statistics (E[x²])
66+
for each targeted module via forward pre-hooks.
67+
68+
Stores ``module._imatrix_importance`` as a 1-D float32 tensor
69+
of shape ``(in_features,)`` on each target module.
70+
71+
Does NOT quantize. Does NOT modify weights.
72+
73+
Statistics are kept on GPU during calibration for speed, then
74+
offloaded to CPU at CALIBRATION_EPOCH_END to free GPU memory
75+
before quantization begins.
76+
77+
:param ignore: layer name patterns to skip (default: ``["lm_head"]``)
78+
:param targets: module types to instrument (default: ``["Linear"]``)
79+
"""
80+
81+
ignore: Union[str, List[str]] = Field(
82+
default_factory=lambda: ["lm_head"],
83+
)
84+
targets: Union[str, List[str]] = Field(
85+
default_factory=lambda: ["Linear"],
86+
)
87+
88+
# -- internal state (excluded from serialisation) --
89+
_target_names: Optional[List[str]] = None
90+
_sums: Optional[Dict[str, torch.Tensor]] = None
91+
_counts: Optional[Dict[str, int]] = None
92+
93+
# ------------------------------------------------------------------ #
94+
# Lifecycle
95+
# ------------------------------------------------------------------ #
96+
97+
def on_initialize(self, state: State, **kwargs) -> bool:
98+
if self.end and self.end != -1:
99+
raise ValueError(
100+
f"{self.__class__.__name__} can only be applied "
101+
f"during one-shot. Expected end to be None or "
102+
f"-1, got {self.end}"
103+
)
104+
if self.start and self.start != -1:
105+
raise ValueError(
106+
f"{self.__class__.__name__} can only be applied "
107+
f"during one-shot. Expected start to be None "
108+
f"or -1, got {self.start}"
109+
)
110+
111+
self._resolve_targets(state.model)
112+
return True
113+
114+
def on_start(self, state: State, event: Event, **kwargs):
115+
self.started_ = True
116+
self._register_accumulation_hooks(state.model)
117+
118+
def on_event(self, state: State, event: Event, **kwargs):
119+
if event.type_ == EventType.CALIBRATION_EPOCH_START:
120+
if not self.started_:
121+
self.on_start(state, None)
122+
123+
if event.type_ == EventType.SEQUENTIAL_EPOCH_END:
124+
self._compute_and_attach(state.model)
125+
126+
if event.type_ == EventType.CALIBRATION_EPOCH_END:
127+
self._compute_and_attach(state.model, offload_to_cpu=True)
128+
129+
if not self.ended_:
130+
self.on_end(state, None)
131+
132+
def on_end(self, state: State, event: Event, **kwargs):
133+
self.ended_ = True
134+
self.remove_hooks()
135+
136+
def on_finalize(self, state: State, **kwargs) -> bool:
137+
if not self.ended_:
138+
self.on_end(state, None)
139+
140+
self._sums = None
141+
self._counts = None
142+
self._target_names = None
143+
return True
144+
145+
# ------------------------------------------------------------------ #
146+
# Target resolution
147+
# ------------------------------------------------------------------ #
148+
149+
def _resolve_targets(self, model: Module):
150+
"""Identify target modules using compressed_tensors matching."""
151+
self._target_names = []
152+
self._sums = {}
153+
self._counts = {}
154+
155+
for name, module in match_named_modules(model, self.targets, self.ignore):
156+
if not hasattr(module, "in_features"):
157+
continue
158+
159+
self._target_names.append(name)
160+
self._sums[name] = torch.zeros(module.in_features, dtype=torch.float32)
161+
self._counts[name] = 0
162+
163+
logger.info(f"IMatrixGatherer: targeting {len(self._target_names)}" f" modules")
164+
165+
# ------------------------------------------------------------------ #
166+
# Hook registration
167+
# ------------------------------------------------------------------ #
168+
169+
def _register_accumulation_hooks(self, model: Module):
170+
"""Attach a forward-pre hook to every target module."""
171+
172+
def _create_hook_fn(layer_name: str):
173+
"""Closure captures layer_name."""
174+
175+
def _hook(module: Module, args):
176+
x = args[0] if not isinstance(args, torch.Tensor) else args
177+
if isinstance(x, tuple):
178+
x = x[0]
179+
if not isinstance(x, torch.Tensor):
180+
return
181+
182+
# Mean per sample (each sample weighted equally,
183+
# regardless of sequence length)
184+
x_f = x.detach().float()
185+
sample_mean = x_f.pow(2).mean(dim=list(range(x_f.dim() - 1)))
186+
187+
device = self._sums[layer_name].device
188+
if device != sample_mean.device:
189+
self._sums[layer_name] = self._sums[layer_name].to(
190+
sample_mean.device
191+
)
192+
193+
self._sums[layer_name].add_(sample_mean)
194+
self._counts[layer_name] += 1
195+
196+
return _hook
197+
198+
for name, module in match_named_modules(model, self.targets, self.ignore):
199+
if name in self._sums:
200+
self.register_hook(
201+
module,
202+
_create_hook_fn(name),
203+
"forward_pre",
204+
)
205+
206+
# ------------------------------------------------------------------ #
207+
# Compute & attach
208+
# ------------------------------------------------------------------ #
209+
210+
def _compute_and_attach(self, model: Module, offload_to_cpu: bool = False):
211+
"""
212+
Compute E[x²] and store on each module.
213+
214+
:param model: model whose modules receive importance data
215+
:param offload_to_cpu: if True, move importance tensors to CPU
216+
after attaching. Set at CALIBRATION_EPOCH_END to free
217+
GPU memory before quantization.
218+
"""
219+
attached = 0
220+
for name, module in match_named_modules(model, self.targets, self.ignore):
221+
if name not in self._sums:
222+
continue
223+
224+
count = self._counts[name]
225+
if count == 0:
226+
continue
227+
228+
importance = self._sums[name] / count
229+
230+
if offload_to_cpu:
231+
importance = importance.to("cpu")
232+
# also free the accumulator
233+
del self._sums[name]
234+
235+
module._imatrix_importance = importance
236+
237+
attached += 1
238+
logger.debug(
239+
f"iMatrix {name}: "
240+
f"mean={importance.mean():.4f}, "
241+
f"max={importance.max():.4f}, "
242+
f"ratio="
243+
f"{importance.max() / (importance.mean() + 1e-10):.1f}"
244+
)
245+
246+
logger.info(
247+
f"IMatrixGatherer: attached importance to "
248+
f"{attached} modules" + (" (offloaded to CPU)" if offload_to_cpu else "")
249+
)

src/llmcompressor/observers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,4 @@
1414
from .moving_base import *
1515
from .min_max import *
1616
from .mse import *
17+
from .imatrix import *

0 commit comments

Comments
 (0)