|
9 | 9 | from loguru import logger |
10 | 10 | from pydantic import Field, PrivateAttr, field_validator, model_validator |
11 | 11 |
|
12 | | -from llmcompressor.core import State |
| 12 | +from llmcompressor.core import Event, EventType, State |
| 13 | +from llmcompressor.modifiers.modifier import Modifier |
13 | 14 | from llmcompressor.modifiers.utils.hooks import HooksMixin |
14 | 15 | from llmcompressor.pipelines.basic import run_pipeline as run_basic |
15 | 16 | from llmcompressor.utils.pytorch.module import ( |
|
20 | 21 | ) |
21 | 22 |
|
22 | 23 |
|
23 | | -class SparsityModifierMixin(HooksMixin): |
| 24 | +class SparsityModifierMixin(Modifier): |
24 | 25 | # modifier arguments |
25 | 26 | sparsity: Optional[Union[float, List[float]]] |
26 | 27 | sparsity_profile: Optional[str] = None |
@@ -93,6 +94,10 @@ def calibrate_module( |
93 | 94 | ): |
94 | 95 | raise NotImplementedError() |
95 | 96 |
|
| 97 | + @abstractmethod |
| 98 | + def compress_modules(self): |
| 99 | + raise NotImplementedError() |
| 100 | + |
96 | 101 | def on_initialize(self, state: "State", **kwargs) -> bool: |
97 | 102 | """ |
98 | 103 | Initialize and run the OBCQ algorithm on the current state |
@@ -158,6 +163,21 @@ def on_initialize(self, state: "State", **kwargs) -> bool: |
158 | 163 |
|
159 | 164 | return True |
160 | 165 |
|
| 166 | + def on_event(self, state: State, event: Event, **kwargs): |
| 167 | + if event.type_ == EventType.SEQUENTIAL_EPOCH_END: |
| 168 | + self.compress_modules() |
| 169 | + |
| 170 | + if event.type_ == EventType.CALIBRATION_EPOCH_END: |
| 171 | + self.compress_modules() |
| 172 | + |
| 173 | + # TODO: modify lifecycle to end on calibration epoch end |
| 174 | + if not self.ended_: |
| 175 | + self.on_end(state, None) |
| 176 | + |
| 177 | + def on_end(self, state: State, event: Event, **kwargs): |
| 178 | + self.ended_ = True # TODO: move to super call |
| 179 | + self.remove_hooks() |
| 180 | + |
161 | 181 | def _infer_sequential_targets( |
162 | 182 | self, model: torch.nn.Module |
163 | 183 | ) -> Union[str, List[str]]: |
|
0 commit comments