Skip to content

Commit b9c91e7

Browse files
committed
remove hooks on calibration epoch end
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 8888a33 commit b9c91e7

File tree

6 files changed

+80
-44
lines changed

6 files changed

+80
-44
lines changed

src/llmcompressor/modifiers/obcq/base.py

Lines changed: 5 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,7 @@
1010
from loguru import logger
1111
from pydantic import PrivateAttr
1212

13-
from llmcompressor.core import Event, EventType, State
13+
from llmcompressor.core import State
1414
from llmcompressor.modifiers import Modifier
1515
from llmcompressor.modifiers.obcq.sgpt_mixin import SparsityModifierMixin
1616
from llmcompressor.modifiers.obcq.sgpt_sparsify import (
@@ -113,13 +113,6 @@ def calibrate_module(
113113
self._num_samples[module],
114114
)
115115

116-
def on_event(self, state: State, event: Event, **kwargs):
117-
if event.type_ in (
118-
EventType.SEQUENTIAL_EPOCH_END,
119-
EventType.CALIBRATION_EPOCH_END,
120-
):
121-
self.compress_modules()
122-
123116
def compress_modules(self):
124117
"""
125118
Sparsify modules which have been calibrated
@@ -163,10 +156,13 @@ def _maybe_onload_hessian(self, module: torch.nn.Module):
163156
self._hessians[module] = self._hessians[module].to(device="cpu")
164157

165158
def on_finalize(self, state: State, **kwargs) -> bool:
159+
# TODO: modify lifecycle to end on finalize
160+
if not self.ended_:
161+
self.on_end(state, None) # remove hooks
162+
166163
if len(self._num_samples) > 0:
167164
raise ValueError(f"Failed to compress {len(self._num_samples)} modules")
168165

169-
self.remove_hooks()
170166
self._hessians = dict()
171167
self._num_samples = dict()
172168
self._module_names = dict()

src/llmcompressor/modifiers/obcq/sgpt_mixin.py

Lines changed: 22 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,8 @@
99
from loguru import logger
1010
from pydantic import Field, PrivateAttr, field_validator, model_validator
1111

12-
from llmcompressor.core import State
12+
from llmcompressor.core import Event, EventType, State
13+
from llmcompressor.modifiers.modifier import Modifier
1314
from llmcompressor.modifiers.utils.hooks import HooksMixin
1415
from llmcompressor.pipelines.basic import run_pipeline as run_basic
1516
from llmcompressor.utils.pytorch.module import (
@@ -20,7 +21,7 @@
2021
)
2122

2223

23-
class SparsityModifierMixin(HooksMixin):
24+
class SparsityModifierMixin(Modifier):
2425
# modifier arguments
2526
sparsity: Optional[Union[float, List[float]]]
2627
sparsity_profile: Optional[str] = None
@@ -93,6 +94,10 @@ def calibrate_module(
9394
):
9495
raise NotImplementedError()
9596

97+
@abstractmethod
98+
def compress_modules(self):
99+
raise NotImplementedError()
100+
96101
def on_initialize(self, state: "State", **kwargs) -> bool:
97102
"""
98103
Initialize and run the OBCQ algorithm on the current state
@@ -158,6 +163,21 @@ def on_initialize(self, state: "State", **kwargs) -> bool:
158163

159164
return True
160165

166+
def on_event(self, state: State, event: Event, **kwargs):
167+
if event.type_ == EventType.SEQUENTIAL_EPOCH_END:
168+
self.compress_modules()
169+
170+
if event.type_ == EventType.CALIBRATION_EPOCH_END:
171+
self.compress_modules()
172+
173+
# TODO: modify lifecycle to end on calibration epoch end
174+
if not self.ended_:
175+
self.on_end(state, None)
176+
177+
def on_end(self, state: State, event: Event, **kwargs):
178+
self.ended_ = True # TODO: move to super call
179+
self.remove_hooks()
180+
161181
def _infer_sequential_targets(
162182
self, model: torch.nn.Module
163183
) -> Union[str, List[str]]:

src/llmcompressor/modifiers/pruning/wanda/base.py

Lines changed: 5 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@
99
from loguru import logger
1010
from pydantic import PrivateAttr
1111

12-
from llmcompressor.core import Event, EventType, State
12+
from llmcompressor.core import State
1313
from llmcompressor.modifiers import Modifier
1414
from llmcompressor.modifiers.obcq.sgpt_mixin import SparsityModifierMixin
1515
from llmcompressor.modifiers.pruning.wanda.wanda_sparsify import (
@@ -99,13 +99,6 @@ def calibrate_module(
9999
self._num_samples[module],
100100
)
101101

102-
def on_event(self, state: State, event: Event, **kwargs):
103-
if event.type_ in (
104-
EventType.SEQUENTIAL_EPOCH_END,
105-
EventType.CALIBRATION_EPOCH_END,
106-
):
107-
self.compress_modules()
108-
109102
def compress_modules(self):
110103
"""
111104
Sparsify modules which have been calibrated
@@ -133,10 +126,13 @@ def compress_modules(self):
133126
del self._num_samples[module]
134127

135128
def on_finalize(self, state: State, **kwargs) -> bool:
129+
# TODO: modify lifecycle to end on finalize
130+
if not self.ended_:
131+
self.on_end(state, None) # remove hooks
132+
136133
if len(self._num_samples) > 0:
137134
raise ValueError(f"Failed to compress {len(self._num_samples)} modules")
138135

139-
self.remove_hooks()
140136
self._row_scalars = dict()
141137
self._num_samples = dict()
142138
self._module_names = dict()

src/llmcompressor/modifiers/quantization/gptq/base.py

Lines changed: 23 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -160,21 +160,41 @@ def on_initialize(self, state: State, **kwargs) -> bool:
160160

161161
return True
162162

163+
def on_event(self, state: State, event: Event, **kwargs):
164+
if event.type_ == EventType.SEQUENTIAL_EPOCH_END:
165+
self.compress_modules()
166+
167+
if event.type_ == EventType.CALIBRATION_EPOCH_END:
168+
self.compress_modules()
169+
170+
# TODO: modify lifecycle to end on calibration epoch end
171+
if not self.ended_:
172+
self.on_end(state, None)
173+
174+
def on_end(self, state: State, event: Event, **kwargs):
175+
"""
176+
Finish calibrating by removing observers and calibration hooks
177+
"""
178+
self.ended_ = True # TODO: move to super call
179+
state.model.apply(freeze_module_quantization) # remove observers
180+
self.remove_hooks() # remove hooks
181+
163182
def on_finalize(self, state: State, **kwargs) -> bool:
164183
"""
165184
disable the quantization observers used by the OBCQ algorithm
166185
167186
:param state: session state storing input model and calibration data
168187
"""
188+
# TODO: modify lifecycle to end on finalize
189+
if not self.ended_:
190+
self.on_end(state, None)
191+
169192
if len(self._num_samples) > 0:
170193
raise ValueError(f"Failed to compress {len(self._num_samples)} modules")
171194

172195
self._hessians = dict()
173196
self._num_samples = dict()
174197

175-
state.model.apply(freeze_module_quantization) # remove observers
176-
self.remove_hooks() # remove hooks
177-
178198
return True
179199

180200
def calibrate_module(
@@ -211,13 +231,6 @@ def calibrate_module(
211231
self._num_samples[module],
212232
)
213233

214-
def on_event(self, state: State, event: Event, **kwargs):
215-
if event.type_ in (
216-
EventType.SEQUENTIAL_EPOCH_END,
217-
EventType.CALIBRATION_EPOCH_END,
218-
):
219-
self.compress_modules()
220-
221234
def compress_modules(self):
222235
"""
223236
Quantize modules which have been calibrated

src/llmcompressor/modifiers/quantization/quantization/base.py

Lines changed: 9 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
import tqdm
22
from compressed_tensors.quantization import disable_quantization, enable_quantization
33

4-
from llmcompressor.core import Event, State
4+
from llmcompressor.core import Event, EventType, State
55
from llmcompressor.modifiers import Modifier
66
from llmcompressor.modifiers.quantization.calibration import (
77
apply_calibration_status,
@@ -81,14 +81,21 @@ def on_start(self, state: State):
8181
for module in tqdm.tqdm(modules, desc="Calibrating weights"):
8282
update_weight_zp_scale(module)
8383

84+
def on_event(self, state: State, event: Event, **kwargs):
85+
if event.type_ == EventType.CALIBRATION_EPOCH_END:
86+
# TODO: modify lifecycle to end on calibration epoch end
87+
if not self.ended_:
88+
self.on_end(state, None)
89+
8490
def on_end(self, state: State, event: Event, **kwargs):
8591
"""
8692
Finish calibrating by removing observers and calibration hooks
8793
"""
94+
self.ended_ = True # TODO: move to super call
8895
state.model.apply(freeze_module_quantization) # remove observers
8996
self.remove_hooks() # remove hooks
9097

9198
def on_finalize(self, state: State, **kwargs) -> bool:
92-
# TODO: modify lifecycle so modifiers end on finalize
99+
# TODO: modify lifecycle to end on finalize
93100
if not self.ended_:
94101
self.on_end(state, None)

src/llmcompressor/modifiers/smoothquant/base.py

Lines changed: 16 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -137,27 +137,31 @@ def on_initialize(self, state: State, **kwargs) -> bool:
137137
return True
138138

139139
def on_event(self, state: State, event: Event, **kwargs):
140-
"""
141-
Sparsify modules which have been calibrated with samples
142-
"""
143-
if event.type_ in (
144-
EventType.SEQUENTIAL_EPOCH_END,
145-
EventType.CALIBRATION_EPOCH_END,
146-
):
140+
if event.type_ == EventType.SEQUENTIAL_EPOCH_END:
147141
self._apply_smoothing(state.model)
148142

143+
if event.type_ == EventType.CALIBRATION_EPOCH_END:
144+
self._apply_smoothing(state.model)
145+
146+
# TODO: modify lifecycle to end on calibration epoch end
147+
if not self.ended_:
148+
self.on_end(state, None)
149+
150+
def on_end(self, state: State, event: Event, **kwargs):
151+
self.ended_ = True # TODO: move to super calls
152+
self.remove_hooks() # remove hooks
153+
149154
def on_finalize(self, state: State, **kwargs) -> bool:
150155
"""
151156
Clean up by clearing the scale and mapping data
152-
153-
:param state: unused
154-
:return: True
155157
"""
158+
# TODO: modify lifecycle to end on finalize
159+
if not self.ended_:
160+
self.on_end(state, None)
161+
156162
if len(self.scales_) > 0:
157163
raise ValueError(f"Failed to compress {len(self.scales_)} modules")
158164

159-
self.remove_hooks()
160-
161165
if self.scales_ is not None:
162166
self.scales_.clear()
163167
if self.resolved_mappings_ is not None:

0 commit comments

Comments
 (0)