 from loguru import logger
 from pydantic import Field, PrivateAttr, field_validator, model_validator

-from llmcompressor.core import State
+from llmcompressor.core import Event, EventType, State
+from llmcompressor.modifiers.modifier import Modifier
 from llmcompressor.modifiers.utils.hooks import HooksMixin
 from llmcompressor.pipelines.basic import run_pipeline as run_basic
-from llmcompressor.pipelines.layer_sequential import (
-    run_pipeline as run_layer_sequential,
-)
-from llmcompressor.pipelines.sequential import run_pipeline as run_sequential
 from llmcompressor.utils.pytorch.module import (
     get_layers,
     get_no_split_params,
[...]
 )


-class SparsityModifierMixin(HooksMixin):
+class SparsityModifierMixin(Modifier):
     # modifier arguments
     sparsity: Optional[Union[float, List[float]]]
     sparsity_profile: Optional[str] = None
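
Note on the base-class change: deriving from Modifier rather than HooksMixin opts the mixin into the full modifier lifecycle instead of hook management alone (Modifier itself still inherits HooksMixin, so the register_hook/remove_hooks calls in this file keep working). Roughly, the inherited surface looks like the following simplified paraphrase, not the actual class body:

class Modifier(HooksMixin):
    def on_initialize(self, state: State, **kwargs) -> bool:
        ...  # called once when compression begins

    def on_event(self, state: State, event: Event, **kwargs):
        ...  # called for each lifecycle event the pipeline dispatches

    def on_end(self, state: State, event: Event, **kwargs):
        ...  # called when the modifier's lifetime ends
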
@@ -97,6 +94,10 @@ def calibrate_module(
     ):
         raise NotImplementedError()

+    @abstractmethod
+    def compress_modules(self):
+        raise NotImplementedError()
+
     def on_initialize(self, state: "State", **kwargs) -> bool:
         """
         Initialize and run the OBCQ algorithm on the current state
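
For orientation, the split this introduces: subclasses gather statistics in calibrate_module (invoked by the forward hooks registered in on_initialize) and apply the actual weight updates in compress_modules, which the event handler below fires at epoch boundaries. A minimal sketch of a conforming subclass, substituting naive magnitude pruning for the real OBCQ/SparseGPT math; the class name and pruning body are illustrative only, and it assumes the hooked modules are Linear-like with a .weight:

import torch

class MagnitudeSparsityModifier(SparsityModifierMixin):
    def calibrate_module(self, *args, **kwargs):
        # magnitude pruning needs no calibration statistics
        pass

    def compress_modules(self):
        # zero the smallest-magnitude weights of every hooked module,
        # using the per-module target recorded during on_initialize
        for module, sparsity in self._module_sparsities.items():
            weight = module.weight.data
            k = int(weight.numel() * sparsity)
            if k > 0:
                threshold = weight.abs().flatten().kthvalue(k).values
                weight.mul_((weight.abs() > threshold).to(weight.dtype))
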
@@ -160,48 +161,22 @@ def on_initialize(self, state: "State", **kwargs) -> bool:
                 self._module_sparsities[module] = layer_sparsity
                 self.register_hook(module, self.calibrate_module, "forward")

-        # infer and run pipeline
-        model_name = state.model.__class__.__name__
-        input_names = dataloader.dataset.column_names
-        unfixable_errors = (torch.OutOfMemoryError, torch._C._LinAlgError)
-        try:
-            run_sequential(
-                state.model,
-                state.data.calib,
-                self.sequential_targets,
-                self.ignore,
-                self,
-            )
-            return True
-
-        except Exception as exception:
-            if isinstance(exception, torch.fx.proxy.TraceError):
-                warnings.warn(f"Failed to trace {model_name} with inputs {input_names}")
-            if isinstance(exception, unfixable_errors):
-                raise exception
-
-            warnings.warn("Falling back to layer_sequential pipeline")
-            try:
-                run_layer_sequential(
-                    state.model,
-                    state.data.calib,
-                    self.sequential_targets,
-                    self,
-                )
-                return True
-
-            except Exception as exception:
-                if isinstance(exception, TypeError):
-                    warnings.warn(f"{model_name} fails layer-wise assumptions")
-                if isinstance(exception, unfixable_errors):
-                    raise exception
-
-                warnings.warn(
-                    "Falling back to basic pipeline, which requires extra memory and "
-                    "may result in decreased accuracy"
-                )
-                run_basic(state.model, state.data.calib, self)
-                return True
+        return True
+
+    def on_event(self, state: State, event: Event, **kwargs):
+        if event.type_ == EventType.SEQUENTIAL_EPOCH_END:
+            self.compress_modules()
+
+        if event.type_ == EventType.CALIBRATION_EPOCH_END:
+            self.compress_modules()
+
+            # TODO: modify lifecycle to end on calibration epoch end
+            if not self.ended_:
+                self.on_end(state, None)
+
+    def on_end(self, state: State, event: Event, **kwargs):
+        self.ended_ = True  # TODO: move to super call
+        self.remove_hooks()

     def _infer_sequential_targets(
         self, model: torch.nn.Module
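
The effect of this hunk: on_initialize now only registers hooks and returns, and the choice of sequential vs. layer-sequential vs. basic pipeline (with its tracing fallbacks) moves out of the modifier entirely; the calibration pipeline drives compression by dispatching events. A hypothetical driver loop sketching that direction of control; run_sequential_calibration, subgraphs, and subgraph.forward are illustrative stand-ins, while Event, EventType, and the on_event contract come from this diff:

from llmcompressor.core import Event, EventType

def run_sequential_calibration(modifier, state, subgraphs, batches):
    # run every calibration batch through one subgraph at a time; the
    # forward hooks registered in on_initialize collect statistics, then
    # SEQUENTIAL_EPOCH_END lets the modifier compress those layers
    # before the next subgraph is calibrated
    for subgraph in subgraphs:
        for batch in batches:
            subgraph.forward(batch)  # illustrative stand-in
        modifier.on_event(state, Event(type_=EventType.SEQUENTIAL_EPOCH_END))

    # end of calibration: compress anything outstanding and, per the
    # TODO above, also tear down hooks via on_end
    modifier.on_event(state, Event(type_=EventType.CALIBRATION_EPOCH_END))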