-from abc import abstractmethod
-from collections import defaultdict
-from functools import partial
-from typing import Any, Dict, List, Optional, Tuple, Union
+import warnings

-import numpy
-import torch
-from loguru import logger
-from pydantic import Field, PrivateAttr, field_validator, model_validator
-
-from llmcompressor.core import Event, EventType, State
-from llmcompressor.modifiers.modifier import Modifier
-from llmcompressor.modifiers.utils.hooks import HooksMixin
-from llmcompressor.utils.pytorch.module import (
-    get_layers,
-    get_no_split_params,
-    get_prunable_layers,
-    match_targets,
+from llmcompressor.modifiers.pruning.sparsegpt import (
+    SparseGPTModifier as PruningSparseGPTModifier,
 )

+__all__ = ["SparseGPTModifier"]

-class SparsityModifierBase(Modifier):
-    """
-    Abstract base class which implements functionality related to oneshot sparsity.
-    Inheriters must implement `calibrate_module` and `compress_modules`
-    """
-
-    # modifier arguments
-    sparsity: Optional[Union[float, List[float]]]
-    sparsity_profile: Optional[str] = None
-    mask_structure: str = "0:0"
-    owl_m: Optional[int] = None
-    owl_lmbda: Optional[float] = None
-
-    # data pipeline arguments
-    sequential_targets: Union[str, List[str], None] = None
-    targets: Union[str, List[str]] = ["Linear"]
-    ignore: List[str] = Field(default_factory=list)
-
-    # private variables
-    _prune_n: Optional[int] = PrivateAttr(default=None)
-    _prune_m: Optional[int] = PrivateAttr(default=None)
-    _module_names: Dict[torch.nn.Module, str] = PrivateAttr(default_factory=dict)
-    _target_layers: Dict[str, torch.nn.Module] = PrivateAttr(default_factory=dict)
-    _module_sparsities: Dict[torch.nn.Module, str] = PrivateAttr(default_factory=dict)
-
-    @field_validator("sparsity_profile", mode="before")
-    def validate_sparsity_profile(cls, value: Optional[str]) -> bool:
-        if value is None:
-            return value
-
-        value = value.lower()
-
-        profile_options = ["owl"]
-        if value not in profile_options:
-            raise ValueError(f"Please choose profile from {profile_options}")
-
-        return value
-
-    @model_validator(mode="after")
-    def validate_model_after(model: "SparsityModifierBase") -> "SparsityModifierBase":
-        profile = model.sparsity_profile
-        owl_m = model.owl_m
-        owl_lmbda = model.owl_lmbda
-        mask_structure = model.mask_structure
-
-        has_owl_m = owl_m is not None
-        has_owl_lmbda = owl_lmbda is not None
-        has_owl = profile == "owl"
-        owl_args = (has_owl_m, has_owl_lmbda, has_owl)
-        if any(owl_args) and not all(owl_args):
-            raise ValueError(
-                'Must provide all of `profile="owl"`, `owl_m` and `owl_lmbda` or none'
-            )
-
-        model._prune_n, model._prune_m = model._split_mask_structure(mask_structure)
-
-        return model
-
-    @abstractmethod
-    def calibrate_module(
-        self,
-        module: torch.nn.Module,
-        args: Tuple[torch.Tensor, ...],
-        _output: torch.Tensor,
-    ):
-        raise NotImplementedError()
-
-    @abstractmethod
-    def compress_modules(self):
-        raise NotImplementedError()
-
-    def on_initialize(self, state: "State", **kwargs) -> bool:
-        """
-        Initialize and run the OBCQ algorithm on the current state
-
-        :param state: session state storing input model and calibration data
-        """
-        model: torch.nn.Module = state.model
-        dataloader: torch.utils.data.DataLoader = state.data.calib
-
-        # infer module and sequential targets
-        self.sequential_targets = self._infer_sequential_targets(model)
-        layers = get_layers(self.sequential_targets, model)
-        self._target_layers = get_layers(
-            self.targets, model
-        )  # layers containing targets
-
-        # infer layer sparsities
-        if self.sparsity_profile == "owl":
-            logger.info(
-                "Using OWL to infer target layer-wise sparsities from "
-                f"{len(dataloader) if dataloader else 0} calibration samples..."
-            )
-            self.sparsity = self._infer_owl_layer_sparsity(model, layers, dataloader)
-
-        # get layers and validate sparsity
-        if isinstance(self.sparsity, (list, dict)) and len(self._target_layers) != len(
-            self.sparsity
-        ):
-            raise ValueError(
-                f"{self.__repr_name__} was initialized with {len(self.sparsity)} "
-                f"sparsities values, but model has {len(layers)} target layers"
-            )
-
-        return True
-
-    def on_start(self, state: State, event: Event, **kwargs):
-        self.started_ = True
+# Legacy shim for backwards-compat imports

-        # register hooks
-        for index, (layer_name, layer) in enumerate(self._target_layers.items()):
-            if isinstance(self.sparsity, dict):
-                layer_sparsity = self.sparsity[layer_name]
-            elif isinstance(self.sparsity, list):
-                layer_sparsity = self.sparsity[index]
-            else:
-                layer_sparsity = self.sparsity

-            for name, module in get_prunable_layers(layer).items():
-                name = f"{layer_name}.{name}"
-
-                if match_targets(name, self.ignore)[0]:
-                    continue
-
-                # HACK: previously, embeddings were not quantized because they were not
-                # accessible by the layer compressor. For now, we manually ignore it,
-                # but in the FUTURE this should be ignored by the user
-                if isinstance(module, torch.nn.Embedding):
-                    continue
-
-                if name.endswith("lm_head"):
-                    logger.warning(
-                        "`lm_head` was previously auto-ignored by SparseGPT and Wanda "
-                        "modifiers and is not advised. Please add `re:.*lm_head` to "
-                        "your ignore list if this was unintentional"
-                    )
-
-                self._module_names[module] = name
-                self._module_sparsities[module] = layer_sparsity
-                self.register_hook(module, self.calibrate_module, "forward")
-
-    def on_event(self, state: State, event: Event, **kwargs):
-        if event.type_ == EventType.CALIBRATION_EPOCH_START:
-            if not self.started_:
-                self.on_start(state, None)
-
-        if event.type_ == EventType.SEQUENTIAL_EPOCH_END:
-            self.compress_modules()
-
-        if event.type_ == EventType.CALIBRATION_EPOCH_END:
-            self.compress_modules()
-
-            if not self.ended_:
-                self.on_end(state, None)
-
-    def on_end(self, state: State, event: Event, **kwargs):
-        self.ended_ = True
-        self.remove_hooks()
-
-    def _infer_sequential_targets(
-        self, model: torch.nn.Module
-    ) -> Union[str, List[str]]:
-        if self.sequential_targets is None:
-            return get_no_split_params(model)
-        if isinstance(self.sequential_targets, str):
-            return [self.sequential_targets]
-        return self.sequential_targets
-
-    def _infer_owl_layer_sparsity(
-        self,
-        model: torch.nn.Module,
-        layers: Dict[str, torch.nn.Module],
-        dataloader: torch.utils.data.DataLoader,
-    ) -> Dict[str, float]:
-        activations = self._get_activations(model, dataloader)
-
-        groups = {}
-        for name, layer in layers.items():
-            prunable_layers = get_prunable_layers(layer)
-            z = [
-                m.weight.abs() * activations[f"{name}.{n}"].unsqueeze(0)
-                for n, m in prunable_layers.items()
-            ]
-            groups[name] = torch.cat([item.flatten().cpu() for item in z])
-
-        del activations
-
-        outlier_ratios = {}
-        for group in groups:
-            threshold = torch.mean(groups[group]) * self.owl_m
-            outlier_ratios[group] = (
-                100 * (groups[group] > threshold).sum().item() / groups[group].numel()
-            )
-        outlier_ratios_arr = numpy.array([outlier_ratios[k] for k in outlier_ratios])
-        for k in outlier_ratios:
-            outlier_ratios[k] = (outlier_ratios[k] - outlier_ratios_arr.min()) * (
-                1
-                / (outlier_ratios_arr.max() - outlier_ratios_arr.min())
-                * self.owl_lmbda
-                * 2
-            )
-        outlier_ratios_arr = numpy.array([outlier_ratios[k] for k in outlier_ratios])
-        sparsities = {
-            k: 1
-            - (
-                outlier_ratios[k]
-                - numpy.mean(outlier_ratios_arr)
-                + (1 - float(self.sparsity))
-            )
-            for k in outlier_ratios
-        }
-        logger.info(f"OWL sparsities for sp={self.sparsity} are:")
-        for k in sparsities:
-            logger.info(f"Sparsity for {k}: {sparsities[k]}")
-        return sparsities
-
-    def _get_activations(self, model, dataloader, nsamples=128) -> Dict[str, int]:
-        from llmcompressor.pipelines.basic import run_calibration
-
-        acts = defaultdict(int)
-
-        def save_acts(_module, input: Union[Tuple[Any, ...], torch.Tensor], name: str):
-            nonlocal acts
-            if isinstance(input, tuple):
-                input = input[0]
-            acts[name] += 1.0 / nsamples * input.pow(2).sum(dim=(0, 1)).sqrt()
-
-        hooks = set(
-            self.register_hook(mod, partial(save_acts, name=name), "forward_pre")
-            for name, mod in model.named_modules()
-            if isinstance(mod, torch.nn.Linear) and "lm_head" not in name
+class SparseGPTModifier(PruningSparseGPTModifier):
+    def __init__(self, **kwargs):
+        warnings.warn(
+            "SparseGPTModifier has moved. In the future, please initialize it from "
+            "`llmcompressor.modifiers.pruning.sparsegpt.SparseGPTModifier`.",
+            DeprecationWarning,
+            stacklevel=2,  # adjust stacklevel to point to the user's code
         )
-        with HooksMixin.disable_hooks(keep=hooks):
-            run_calibration(model, dataloader)
-        self.remove_hooks(hooks)
-
-        return acts
-
-    def _split_mask_structure(self, mask_structure: str) -> Tuple[int, int]:
-        n, m = mask_structure.split(":")
-        return int(n), int(m)
+        super().__init__(**kwargs)
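For reference, a minimal sketch of how the shim above behaves once this change lands: importing from the legacy module still works, but constructing the modifier emits a `DeprecationWarning` and then defers entirely to the relocated implementation. The legacy import path `llmcompressor.modifiers.obcq` and the `sparsity=0.5` argument are assumptions for illustration; only the new location `llmcompressor.modifiers.pruning.sparsegpt` is confirmed by the diff.

import warnings

# Legacy path is an assumption for illustration; the warning message above
# names the confirmed new location.
from llmcompressor.modifiers.obcq import SparseGPTModifier

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    modifier = SparseGPTModifier(sparsity=0.5)  # hypothetical arguments

# the shim warns once at construction, then forwards all kwargs to the
# relocated PruningSparseGPTModifier via super().__init__
assert any(issubclass(w.category, DeprecationWarning) for w in caught)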