
Commit 09ff1d8

Move OBCQ to SparseGPT (#1842)
SUMMARY: Fix for #1834

TEST PLAN:
- New path imports
- Legacy shim import + warning

---------

Signed-off-by: Brian Dellabetta <[email protected]>
Co-authored-by: Brian Dellabetta <[email protected]>
Co-authored-by: Brian Dellabetta <[email protected]>
1 parent bfa02eb commit 09ff1d8


48 files changed: +368 −307 lines changed
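The test plan above boils down to the two import paths sketched below. This is a minimal sketch, assuming llmcompressor built from this commit is installed; it is not part of the diff itself.

```python
# Minimal sketch of the two imports exercised by the test plan.

# New canonical location introduced by this commit:
from llmcompressor.modifiers.pruning.sparsegpt import SparseGPTModifier

# Legacy location still resolves through a shim; it emits a DeprecationWarning
# when the modifier is constructed (see the shim diff further down):
from llmcompressor.modifiers.obcq import SparseGPTModifier as LegacySparseGPTModifier

# The shim class simply subclasses the relocated modifier.
assert issubclass(LegacySparseGPTModifier, SparseGPTModifier)
```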

.github/workflows/test-check-transformers.yaml

Lines changed: 2 additions & 2 deletions
@@ -101,10 +101,10 @@ jobs:
         if: (success() || failure()) && steps.install.outcome == 'success'
         run: |
           pytest -v tests/llmcompressor/transformers/oneshot
-      - name: Running OBCQ Tests
+      - name: Running SparseGPT Tests
         if: (success() || failure()) && steps.install.outcome == 'success'
         run: |
-          pytest -v tests/llmcompressor/transformers/obcq
+          pytest -v tests/llmcompressor/transformers/sparsegpt
       - name: Running Tracing Tests
         if: (success() || failure()) && steps.install.outcome == 'success'
         run: |

.gitignore

Lines changed: 5 additions & 0 deletions
@@ -804,3 +804,8 @@ wandb/
 timings/
 output_finetune/
 env_log.json
+
+# uv artifacts
+uv.lock
+.venv/
+

src/llmcompressor/modifiers/README.md

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@ are relevant only during training. Below is a summary of the key modifiers avail

 Modifiers that introduce sparsity into a model

-### [SparseGPT](./obcq/base.py)
+### [SparseGPT](./pruning/sparsegpt/base.py)
 One-shot algorithm that uses calibration data to introduce unstructured or structured
 sparsity into weights. Implementation based on [SparseGPT: Massive Language Models Can Be Accurately Pruned in One-Shot](https://arxiv.org/abs/2301.00774). A small amount of calibration data is used
 to calculate a Hessian for each layers input activations, this Hessian is then used to
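For context, here is a hedged usage sketch of the relocated modifier in a one-shot run. The `oneshot` entrypoint is the standard llmcompressor API, while the model id, dataset name, and calibration settings are illustrative assumptions, not part of this diff.

```python
# Hedged sketch: one-shot pruning with the relocated SparseGPTModifier.
from llmcompressor import oneshot
from llmcompressor.modifiers.pruning.sparsegpt import SparseGPTModifier

recipe = SparseGPTModifier(
    sparsity=0.5,             # target unstructured sparsity
    targets=["Linear"],       # prune all Linear modules
    ignore=["re:.*lm_head"],  # keep the LM head dense
)

oneshot(
    model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",  # placeholder model id
    dataset="open_platypus",                     # placeholder calibration dataset
    recipe=recipe,
    max_seq_length=512,                          # calibration sequence length
    num_calibration_samples=64,                  # samples used to build Hessians
)
```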

src/llmcompressor/modifiers/factory.py

Lines changed: 9 additions & 1 deletion
@@ -44,9 +44,17 @@ def load_from_package(package_path: str) -> Dict[str, Type[Modifier]]:
     loaded = {}
     main_package = importlib.import_module(package_path)

-    for importer, modname, is_pkg in pkgutil.walk_packages(
+    # exclude deprecated packages from registry so
+    # their new location is used instead
+    deprecated_packages = [
+        "llmcompressor.modifiers.obcq",
+        "llmcompressor.modifiers.obcq.sgpt_base",
+    ]
+    for _importer, modname, _is_pkg in pkgutil.walk_packages(
         main_package.__path__, package_path + "."
     ):
+        if modname in deprecated_packages:
+            continue
         try:
             module = importlib.import_module(modname)

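The skip added above matters because, per the diff comment, walking into the deprecated packages would register the legacy shim location instead of the relocated class. A standalone sketch of the same pattern follows; it is an illustrative helper, not the actual ModifierFactory code.

```python
# Standalone sketch of a package walk that excludes deprecated modules
# (illustrative helper, not the actual ModifierFactory code).
import importlib
import pkgutil
from typing import Dict, Set, Type


def collect_exports(package_path: str, deprecated: Set[str]) -> Dict[str, Type]:
    exports: Dict[str, Type] = {}
    package = importlib.import_module(package_path)
    for _importer, modname, _is_pkg in pkgutil.walk_packages(
        package.__path__, package_path + "."
    ):
        if modname in deprecated:
            continue  # a legacy shim would otherwise shadow the relocated class
        module = importlib.import_module(modname)
        for name in getattr(module, "__all__", []):
            exports[name] = getattr(module, name)
    return exports
```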
Lines changed: 1 addition & 1 deletion
@@ -1,3 +1,3 @@
 # ruff: noqa

-from .base import *
+from .sgpt_base import *
Lines changed: 13 additions & 255 deletions
@@ -1,262 +1,20 @@
-from abc import abstractmethod
-from collections import defaultdict
-from functools import partial
-from typing import Any, Dict, List, Optional, Tuple, Union
+import warnings

-import numpy
-import torch
-from loguru import logger
-from pydantic import Field, PrivateAttr, field_validator, model_validator
-
-from llmcompressor.core import Event, EventType, State
-from llmcompressor.modifiers.modifier import Modifier
-from llmcompressor.modifiers.utils.hooks import HooksMixin
-from llmcompressor.utils.pytorch.module import (
-    get_layers,
-    get_no_split_params,
-    get_prunable_layers,
-    match_targets,
+from llmcompressor.modifiers.pruning.sparsegpt import (
+    SparseGPTModifier as PruningSparseGPTModifier,
 )

+__all__ = ["SparseGPTModifier"]

-class SparsityModifierBase(Modifier):
-    """
-    Abstract base class which implements functionality related to oneshot sparsity.
-    Inheriters must implement `calibrate_module` and `compress_modules`
-    """
-
-    # modifier arguments
-    sparsity: Optional[Union[float, List[float]]]
-    sparsity_profile: Optional[str] = None
-    mask_structure: str = "0:0"
-    owl_m: Optional[int] = None
-    owl_lmbda: Optional[float] = None
-
-    # data pipeline arguments
-    sequential_targets: Union[str, List[str], None] = None
-    targets: Union[str, List[str]] = ["Linear"]
-    ignore: List[str] = Field(default_factory=list)
-
-    # private variables
-    _prune_n: Optional[int] = PrivateAttr(default=None)
-    _prune_m: Optional[int] = PrivateAttr(default=None)
-    _module_names: Dict[torch.nn.Module, str] = PrivateAttr(default_factory=dict)
-    _target_layers: Dict[str, torch.nn.Module] = PrivateAttr(default_factory=dict)
-    _module_sparsities: Dict[torch.nn.Module, str] = PrivateAttr(default_factory=dict)
-
-    @field_validator("sparsity_profile", mode="before")
-    def validate_sparsity_profile(cls, value: Optional[str]) -> bool:
-        if value is None:
-            return value
-
-        value = value.lower()
-
-        profile_options = ["owl"]
-        if value not in profile_options:
-            raise ValueError(f"Please choose profile from {profile_options}")
-
-        return value
-
-    @model_validator(mode="after")
-    def validate_model_after(model: "SparsityModifierBase") -> "SparsityModifierBase":
-        profile = model.sparsity_profile
-        owl_m = model.owl_m
-        owl_lmbda = model.owl_lmbda
-        mask_structure = model.mask_structure
-
-        has_owl_m = owl_m is not None
-        has_owl_lmbda = owl_lmbda is not None
-        has_owl = profile == "owl"
-        owl_args = (has_owl_m, has_owl_lmbda, has_owl)
-        if any(owl_args) and not all(owl_args):
-            raise ValueError(
-                'Must provide all of `profile="owl"`, `owl_m` and `owl_lmbda` or none'
-            )
-
-        model._prune_n, model._prune_m = model._split_mask_structure(mask_structure)
-
-        return model
-
-    @abstractmethod
-    def calibrate_module(
-        self,
-        module: torch.nn.Module,
-        args: Tuple[torch.Tensor, ...],
-        _output: torch.Tensor,
-    ):
-        raise NotImplementedError()
-
-    @abstractmethod
-    def compress_modules(self):
-        raise NotImplementedError()
-
-    def on_initialize(self, state: "State", **kwargs) -> bool:
-        """
-        Initialize and run the OBCQ algorithm on the current state
-
-        :param state: session state storing input model and calibration data
-        """
-        model: torch.nn.Module = state.model
-        dataloader: torch.utils.data.DataLoader = state.data.calib
-
-        # infer module and sequential targets
-        self.sequential_targets = self._infer_sequential_targets(model)
-        layers = get_layers(self.sequential_targets, model)
-        self._target_layers = get_layers(
-            self.targets, model
-        )  # layers containing targets
-
-        # infer layer sparsities
-        if self.sparsity_profile == "owl":
-            logger.info(
-                "Using OWL to infer target layer-wise sparsities from "
-                f"{len(dataloader) if dataloader else 0} calibration samples..."
-            )
-            self.sparsity = self._infer_owl_layer_sparsity(model, layers, dataloader)
-
-        # get layers and validate sparsity
-        if isinstance(self.sparsity, (list, dict)) and len(self._target_layers) != len(
-            self.sparsity
-        ):
-            raise ValueError(
-                f"{self.__repr_name__} was initialized with {len(self.sparsity)} "
-                f"sparsities values, but model has {len(layers)} target layers"
-            )
-
-        return True
-
-    def on_start(self, state: State, event: Event, **kwargs):
-        self.started_ = True
+# Legacy shim for backwards-compat imports

-        # register hooks
-        for index, (layer_name, layer) in enumerate(self._target_layers.items()):
-            if isinstance(self.sparsity, dict):
-                layer_sparsity = self.sparsity[layer_name]
-            elif isinstance(self.sparsity, list):
-                layer_sparsity = self.sparsity[index]
-            else:
-                layer_sparsity = self.sparsity

-            for name, module in get_prunable_layers(layer).items():
-                name = f"{layer_name}.{name}"
-
-                if match_targets(name, self.ignore)[0]:
-                    continue
-
-                # HACK: previously, embeddings were not quantized because they were not
-                # accessible by the layer compressor. For now, we manually ignore it,
-                # but in the FUTURE this should be ignored by the user
-                if isinstance(module, torch.nn.Embedding):
-                    continue
-
-                if name.endswith("lm_head"):
-                    logger.warning(
-                        "`lm_head` was previously auto-ignored by SparseGPT and Wanda "
-                        "modifiers and is not advised. Please add `re:.*lm_head` to "
-                        "your ignore list if this was unintentional"
-                    )
-
-                self._module_names[module] = name
-                self._module_sparsities[module] = layer_sparsity
-                self.register_hook(module, self.calibrate_module, "forward")
-
-    def on_event(self, state: State, event: Event, **kwargs):
-        if event.type_ == EventType.CALIBRATION_EPOCH_START:
-            if not self.started_:
-                self.on_start(state, None)
-
-        if event.type_ == EventType.SEQUENTIAL_EPOCH_END:
-            self.compress_modules()
-
-        if event.type_ == EventType.CALIBRATION_EPOCH_END:
-            self.compress_modules()
-
-            if not self.ended_:
-                self.on_end(state, None)
-
-    def on_end(self, state: State, event: Event, **kwargs):
-        self.ended_ = True
-        self.remove_hooks()
-
-    def _infer_sequential_targets(
-        self, model: torch.nn.Module
-    ) -> Union[str, List[str]]:
-        if self.sequential_targets is None:
-            return get_no_split_params(model)
-        if isinstance(self.sequential_targets, str):
-            return [self.sequential_targets]
-        return self.sequential_targets
-
-    def _infer_owl_layer_sparsity(
-        self,
-        model: torch.nn.Module,
-        layers: Dict[str, torch.nn.Module],
-        dataloader: torch.utils.data.DataLoader,
-    ) -> Dict[str, float]:
-        activations = self._get_activations(model, dataloader)
-
-        groups = {}
-        for name, layer in layers.items():
-            prunable_layers = get_prunable_layers(layer)
-            z = [
-                m.weight.abs() * activations[f"{name}.{n}"].unsqueeze(0)
-                for n, m in prunable_layers.items()
-            ]
-            groups[name] = torch.cat([item.flatten().cpu() for item in z])
-
-        del activations
-
-        outlier_ratios = {}
-        for group in groups:
-            threshold = torch.mean(groups[group]) * self.owl_m
-            outlier_ratios[group] = (
-                100 * (groups[group] > threshold).sum().item() / groups[group].numel()
-            )
-        outlier_ratios_arr = numpy.array([outlier_ratios[k] for k in outlier_ratios])
-        for k in outlier_ratios:
-            outlier_ratios[k] = (outlier_ratios[k] - outlier_ratios_arr.min()) * (
-                1
-                / (outlier_ratios_arr.max() - outlier_ratios_arr.min())
-                * self.owl_lmbda
-                * 2
-            )
-        outlier_ratios_arr = numpy.array([outlier_ratios[k] for k in outlier_ratios])
-        sparsities = {
-            k: 1
-            - (
-                outlier_ratios[k]
-                - numpy.mean(outlier_ratios_arr)
-                + (1 - float(self.sparsity))
-            )
-            for k in outlier_ratios
-        }
-        logger.info(f"OWL sparsities for sp={self.sparsity} are:")
-        for k in sparsities:
-            logger.info(f"Sparsity for {k}: {sparsities[k]}")
-        return sparsities
-
-    def _get_activations(self, model, dataloader, nsamples=128) -> Dict[str, int]:
-        from llmcompressor.pipelines.basic import run_calibration
-
-        acts = defaultdict(int)
-
-        def save_acts(_module, input: Union[Tuple[Any, ...], torch.Tensor], name: str):
-            nonlocal acts
-            if isinstance(input, tuple):
-                input = input[0]
-            acts[name] += 1.0 / nsamples * input.pow(2).sum(dim=(0, 1)).sqrt()
-
-        hooks = set(
-            self.register_hook(mod, partial(save_acts, name=name), "forward_pre")
-            for name, mod in model.named_modules()
-            if isinstance(mod, torch.nn.Linear) and "lm_head" not in name
+class SparseGPTModifier(PruningSparseGPTModifier):
+    def __init__(cls, **kwargs):
+        warnings.warn(
+            "SparseGPTModifier has moved. In future, please initialize it from "
+            "`llmcompressor.modifiers.pruning.sparsegpt.SparseGPTModifier`.",
+            DeprecationWarning,
+            stacklevel=2,  # Adjust stacklevel to point to the user's code
         )
-        with HooksMixin.disable_hooks(keep=hooks):
-            run_calibration(model, dataloader)
-        self.remove_hooks(hooks)
-
-        return acts
-
-    def _split_mask_structure(self, mask_structure: str) -> Tuple[int, int]:
-        n, m = mask_structure.split(":")
-        return int(n), int(m)
+        return super().__init__(**kwargs)
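A quick sketch of exercising the shim above (what the "legacy shim import + warning" item in the test plan presumably covers; the exact test code is not part of this diff, and the constructor argument is illustrative):

```python
# Hedged sketch: constructing through the legacy path emits a DeprecationWarning
# at the call site (stacklevel=2), while import alone stays silent.
import warnings

from llmcompressor.modifiers.obcq import SparseGPTModifier

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    SparseGPTModifier(sparsity=0.5)  # construction, not import, triggers the warning
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)
```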

src/llmcompressor/modifiers/pruning/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -3,3 +3,4 @@
 from .constant import *
 from .magnitude import *
 from .wanda import *
+from .sparsegpt import *
Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+# ruff: noqa
+
+from .base import SparseGPTModifier
+
+__all__ = ["SparseGPTModifier"]

src/llmcompressor/modifiers/obcq/base.py renamed to src/llmcompressor/modifiers/pruning/sparsegpt/base.py

Lines changed: 6 additions & 5 deletions
@@ -11,8 +11,8 @@
 from pydantic import PrivateAttr

 from llmcompressor.core import State
-from llmcompressor.modifiers.obcq.sgpt_base import SparsityModifierBase
-from llmcompressor.modifiers.obcq.sgpt_sparsify import (
+from llmcompressor.modifiers.pruning.sparsegpt.sgpt_base import SparsityModifierBase
+from llmcompressor.modifiers.pruning.sparsegpt.sgpt_sparsify import (
     accumulate_hessian,
     make_empty_hessian,
     sparsify_weight,
@@ -62,9 +62,10 @@ class SparseGPTModifier(SparsityModifierBase):
         previously pruned model, defaults to False.
     :param offload_hessians: Set to True for decreased memory usage but increased
        runtime.
-    :param sequential_targets: list of layer names to compress during OBCQ, or '__ALL__'
-        to compress every layer in the model. Alias for `targets`
-    :param targets: list of layer names to compress during OBCQ, or '__ALL__'
+    :param sequential_targets: list of layer names to compress
+        during SparseGPT, or '__ALL__' to compress every layer
+        in the model. Alias for `targets`
+    :param targets: list of layer names to compress during SparseGPT, or '__ALL__'
         to compress every layer in the model. Alias for `sequential_targets`
     :param ignore: optional list of module class names or submodule names to not
         quantize even if they match a target. Defaults to empty list.
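For reference, a hedged construction sketch using the parameters documented in the docstring above; the sequential target name and all values are placeholders, not taken from this diff.

```python
# Hedged sketch: constructing the relocated modifier with the documented parameters.
# "LlamaDecoderLayer" is a placeholder sequential target.
from llmcompressor.modifiers.pruning.sparsegpt import SparseGPTModifier

modifier = SparseGPTModifier(
    sparsity=0.5,
    mask_structure="2:4",                      # semi-structured 2:4 sparsity
    sequential_targets=["LlamaDecoderLayer"],  # layers compressed one at a time
    targets=["Linear"],                        # module types pruned within each layer
    ignore=["re:.*lm_head"],                   # keep the LM head dense
    offload_hessians=True,                     # lower memory at the cost of runtime
)
```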
