Commit 9f58887

kylesayrs and dsikka authored
Implement HooksMixin (#917)
* Implement HooksMixin
* add docstring
* integrate with smoothquant
* integrate with QuantizationModifier
* update hooks in tests
* integrate with wanda
* integrate with magnitude and constant
* integrate with SparseGPTModifier
* add hooksmixin to modifier
* nits

Signed-off-by: Kyle Sayers <[email protected]>
Co-authored-by: Dipika Sikka <[email protected]>
1 parent 1830382 commit 9f58887

File tree

11 files changed (+292, -172 lines)


src/llmcompressor/modifiers/modifier.py

Lines changed: 3 additions & 4 deletions
@@ -1,16 +1,15 @@
-from abc import ABC, abstractmethod
+from abc import abstractmethod
 from typing import Optional
 
-from pydantic import BaseModel
-
 from llmcompressor.core.events import Event, EventType
 from llmcompressor.core.state import State
 from llmcompressor.modifiers.interface import ModifierInterface
+from llmcompressor.modifiers.utils.hooks import HooksMixin
 
 __all__ = ["Modifier"]
 
 
-class Modifier(BaseModel, ModifierInterface, ABC):
+class Modifier(ModifierInterface, HooksMixin):
     """
     A base class for all modifiers to inherit from.
     Modifiers are used to modify the training process for a model.
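Note for reviewers: the file defining HooksMixin (src/llmcompressor/modifiers/utils/hooks.py, imported above) is not shown in this excerpt. The snippet below is only an illustrative sketch of the pattern the call sites in this diff rely on, namely self.register_hook(...) to attach a hook and self.remove_hooks() to tear them all down; the real implementation may differ.

from typing import Any, Callable


class SketchHooksMixin:
    """Illustrative only; not the actual HooksMixin added by this PR."""

    def register_hook(self, target: Any, func: Callable, hook_type: str):
        # Dispatch on the hook_type strings used by the call sites in this diff:
        # "forward_pre", "forward", or "" for a hook registered directly on a tensor.
        if not hasattr(self, "_hook_handles"):
            self._hook_handles = []
        if hook_type == "forward_pre":
            handle = target.register_forward_pre_hook(func)
        elif hook_type == "forward":
            handle = target.register_forward_hook(func)
        else:
            handle = target.register_hook(func)  # e.g. a Parameter gradient hook
        self._hook_handles.append(handle)
        return handle

    def remove_hooks(self):
        # Release every handle registered through this mixin.
        for handle in getattr(self, "_hook_handles", []):
            handle.remove()
        self._hook_handles = []

The changed modifiers then follow one pattern: register hooks before the calibration forward passes and call self.remove_hooks() once afterwards, instead of hand-tracking RemovableHandle objects at each call site.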

src/llmcompressor/modifiers/obcq/base.py

Lines changed: 35 additions & 37 deletions
@@ -1,3 +1,4 @@
+from functools import partial
 from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
 
 import numpy as np
@@ -130,7 +131,8 @@ def initialize_compression(
                 "Inferring layer-wise sparsities from "
                 f"{len(dataloader)} calibration samples..."
             )
-            self.sparsity = self._infer_layer_sparsity(dataloader)
+            activations = self._get_activations(dataloader)
+            self.sparsity = self._infer_layer_sparsity(activations)
         self._validate_layerwise_sparsity()
 
         for idx, (name, layer) in enumerate(self.compressible_layers_.items()):
@@ -254,19 +256,17 @@ def _infer_mask_block_size(self):
 
         self.prunen_, self.prunem_ = list(map(int, self.mask_structure.split(":")))
 
-    def _infer_layer_sparsity(self, calibration_dataloader):
-        acts = _get_activations(self.model, calibration_dataloader)
+    def _infer_layer_sparsity(self, activations):
         sparsegpt_groups = {}
         for name, layer in self.compressible_layers_.items():
             prunable_layers = get_prunable_layers(layer)
             z = [
-                m.weight.abs() * acts[f"{name}.{n}"].unsqueeze(0)
+                m.weight.abs() * activations[f"{name}.{n}"].unsqueeze(0)
                 for n, m in prunable_layers.items()
            ]
             sparsegpt_groups[name] = torch.cat([item.flatten().cpu() for item in z])
 
-        acts = None
-        del acts
+        del activations
         torch.cuda.empty_cache()
 
         outlier_ratios = {}
@@ -300,36 +300,34 @@ def _infer_layer_sparsity(self, calibration_dataloader):
             logger.info(f"Sparsity for {k}: {sparsities[k]}")
         return sparsities
 
+    @torch.no_grad()
+    def _get_activations(self, data_loader, nsamples=128):
+        self.model.eval()
+        acts = {}
+
+        def save_acts(module, input, name):
+            if isinstance(input, tuple):
+                input = input[0]
+            if name not in acts:
+                acts[name] = (
+                    1.0 / nsamples * input.detach().pow(2).sum(dim=(0, 1)).sqrt()
+                )
+            else:
+                acts[name] += (
+                    1.0 / nsamples * input.detach().pow(2).sum(dim=(0, 1)).sqrt()
+                )
+
+        for name, mod in self.model.named_modules():
+            if isinstance(mod, torch.nn.Linear) and "lm_head" not in name:
+                self.register_hook(mod, partial(save_acts, name=name), "forward_pre")
+
+        device = next(self.model.parameters()).device
+        for batch in tqdm(data_loader):
+            batch = {k: v.to(device) for k, v in batch.items()}
+            self.model(**batch)
+            batch = None
+        torch.cuda.empty_cache()
 
-@torch.no_grad()
-def _get_activations(model, data_loader, nsamples=128):
-    import functools
-
-    model.eval()
-    acts = {}
-
-    def save_acts(module, input, name):
-        if isinstance(input, tuple):
-            input = input[0]
-        if name not in acts:
-            acts[name] = 1.0 / nsamples * input.detach().pow(2).sum(dim=(0, 1)).sqrt()
-        else:
-            acts[name] += 1.0 / nsamples * input.detach().pow(2).sum(dim=(0, 1)).sqrt()
-
-    hooks = []
-    for name, mod in model.named_modules():
-        if isinstance(mod, torch.nn.Linear) and "lm_head" not in name:
-            hooks.append(
-                mod.register_forward_pre_hook(functools.partial(save_acts, name=name))
-            )
-    device = next(model.parameters()).device
-    for batch in tqdm(data_loader):
-        batch = {k: v.to(device) for k, v in batch.items()}
-        model(**batch)
-        batch = None
-    torch.cuda.empty_cache()
-
-    for h in hooks:
-        h.remove()
+        self.remove_hooks()
 
-    return acts
+        return acts
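The save_acts hook moved into _get_activations above accumulates, for each hooked Linear input, 1.0 / nsamples * sqrt(sum of squared activations over the batch and sequence dims), i.e. a per-input-channel activation norm averaged across the calibration batches. A self-contained toy version follows (the model, shapes, and sample count are invented for the demo), using raw PyTorch handles where the modifier now uses register_hook/remove_hooks:

from functools import partial

import torch

torch.manual_seed(0)
model = torch.nn.Sequential(torch.nn.Linear(8, 16), torch.nn.Linear(16, 4))
nsamples = 4
acts = {}


def save_acts(module, inputs, name):
    x = inputs[0] if isinstance(inputs, tuple) else inputs
    # per-input-channel L2 norm over (batch, seq), averaged over nsamples batches
    stat = 1.0 / nsamples * x.detach().pow(2).sum(dim=(0, 1)).sqrt()
    acts[name] = acts.get(name, 0) + stat


handles = [
    mod.register_forward_pre_hook(partial(save_acts, name=name))
    for name, mod in model.named_modules()
    if isinstance(mod, torch.nn.Linear)
]

with torch.no_grad():
    for _ in range(nsamples):
        model(torch.randn(2, 3, 8))  # (batch, seq_len, hidden)

for handle in handles:
    handle.remove()

print({name: tuple(v.shape) for name, v in acts.items()})
# {'0': (8,), '1': (16,)} -- one statistic per input channel of each Linear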

src/llmcompressor/modifiers/pruning/utils/pytorch/layer_mask.py

Lines changed: 5 additions & 16 deletions
@@ -2,11 +2,10 @@
 from typing import Dict
 
 import torch
-from pydantic import BaseModel
 from torch.nn import Parameter
-from torch.utils.hooks import RemovableHandle
 
 from llmcompressor.core import ModelParameterizedLayer
+from llmcompressor.modifiers.utils.hooks import HooksMixin
 
 __all__ = ["LayerParamMasking", "param_mask_name"]
 
@@ -39,11 +38,9 @@ class ParameterizedLayerMaskSettings:
     use_hooks: bool = False
 
 
-class LayerParamMasking(BaseModel):
+class LayerParamMasking(HooksMixin):
     _mask_settings: Dict[str, ParameterizedLayerMaskSettings] = {}
     _masked_layer_params: Dict[str, ModelParameterizedLayer] = {}
-    _forward_hooks: Dict[str, RemovableHandle] = {}
-    _backward_hooks: Dict[str, RemovableHandle] = {}
     enabled_: bool = False
 
     def add_mask(
@@ -100,12 +97,8 @@ def _backward_hook_fn(gradients):
 
                 return gradients
 
-            self._forward_hooks[layer_param_name] = (
-                parameterized_layer.layer.register_forward_hook(_forward_hook_fn)
-            )
-            self._backward_hooks[layer_param_name] = (
-                parameterized_layer.param.register_hook(_backward_hook_fn)
-            )
+            self.register_hook(parameterized_layer.layer, _forward_hook_fn, "forward")
+            self.register_hook(parameterized_layer.param, _backward_hook_fn, "")
 
     def update_mask(
         self,
@@ -131,11 +124,7 @@ def remove_mask(self, layer_param_name: str):
         del self._mask_settings[layer_param_name]
 
         if mask_settings.use_hooks:
-            self._forward_hooks[layer_param_name].remove()
-            self._backward_hooks[layer_param_name].remove()
-
-            del self._forward_hooks[layer_param_name]
-            del self._backward_hooks[layer_param_name]
+            self.remove_hooks()
 
     def apply_mask_weight(self, layer_param_name: str):
         if not self.enabled_:
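LayerParamMasking previously kept two separate handle dicts because it uses two different kinds of hooks: a module forward hook to adjust outputs and a gradient hook registered directly on the Parameter (the "" hook type in the new register_hook call). A small standalone illustration of that distinction in plain PyTorch; the mask, layer, and hook bodies here are toy stand-ins, not the actual masking logic:

import torch

layer = torch.nn.Linear(4, 4)
mask = (torch.rand_like(layer.weight) > 0.5).float()


def _forward_hook_fn(module, inputs, output):
    # module forward hook: runs after forward(); may return a modified output
    return output


def _backward_hook_fn(grad):
    # tensor hook on the Parameter: runs when its gradient is computed
    return grad * mask


forward_handle = layer.register_forward_hook(_forward_hook_fn)
backward_handle = layer.weight.register_hook(_backward_hook_fn)

layer(torch.randn(2, 4)).sum().backward()
assert torch.all(layer.weight.grad[mask == 0] == 0)  # masked entries get no gradient

# With HooksMixin both handles are tracked by the mixin and released together
# via remove_hooks(); here they are removed by hand.
forward_handle.remove()
backward_handle.remove()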

src/llmcompressor/modifiers/pruning/wanda/base.py

Lines changed: 35 additions & 37 deletions
@@ -1,3 +1,4 @@
+from functools import partial
 from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
 
 import numpy as np
@@ -121,7 +122,8 @@ def initialize_compression(
                 "Inferring layer-wise sparsities from "
                 f"{len(dataloader) if dataloader else 0} calibration samples..."
             )
-            self.sparsity = self._infer_layer_sparsity(dataloader)
+            activations = self._get_activations(dataloader)
+            self.sparsity = self._infer_layer_sparsity(activations)
         self._validate_layerwise_sparsity()
 
         for idx, (name, layer) in enumerate(self.compressible_layers_.items()):
@@ -224,19 +226,17 @@ def _infer_mask_block_size(self):
 
         self.prunen_, self.prunem_ = list(map(int, self.mask_structure.split(":")))
 
-    def _infer_layer_sparsity(self, calibration_dataloader):
-        acts = _get_activations(self.model, calibration_dataloader)
+    def _infer_layer_sparsity(self, activations):
         wanda = {}
         for name, layer in self.compressible_layers_.items():
             prunable_layers = get_prunable_layers(layer)
             z = [
-                m.weight.abs() * acts[f"{name}.{n}"].unsqueeze(0)
+                m.weight.abs() * activations[f"{name}.{n}"].unsqueeze(0)
                 for n, m in prunable_layers.items()
             ]
             wanda[name] = torch.cat([item.flatten().cpu() for item in z])
 
-        acts = None
-        del acts
+        del activations
         torch.cuda.empty_cache()
 
         outlier_ratios = {}
@@ -268,36 +268,34 @@ def _infer_layer_sparsity(self, calibration_dataloader):
             logger.info(f"Sparsity for {k}: {sparsities[k]}")
         return sparsities
 
+    @torch.no_grad()
+    def _get_activations(self, data_loader, nsamples=128):
+        self.model.eval()
+        acts = {}
+
+        def save_acts(module, input, name):
+            if isinstance(input, tuple):
+                input = input[0]
+            if name not in acts:
+                acts[name] = (
+                    1.0 / nsamples * input.detach().pow(2).sum(dim=(0, 1)).sqrt()
+                )
+            else:
+                acts[name] += (
+                    1.0 / nsamples * input.detach().pow(2).sum(dim=(0, 1)).sqrt()
+                )
+
+        for name, mod in self.model.named_modules():
+            if isinstance(mod, torch.nn.Linear) and "lm_head" not in name:
+                self.register_hook(mod, partial(save_acts, name=name), "forward_pre")
+
+        device = next(self.model.parameters()).device
+        for batch in tqdm(data_loader):
+            batch = {k: v.to(device) for k, v in batch.items()}
+            self.model(**batch)
+            batch = None
+        torch.cuda.empty_cache()
 
-@torch.no_grad()
-def _get_activations(model, data_loader, nsamples=128):
-    import functools
-
-    model.eval()
-    acts = {}
-
-    def save_acts(module, input, name):
-        if isinstance(input, tuple):
-            input = input[0]
-        if name not in acts:
-            acts[name] = 1.0 / nsamples * input.detach().pow(2).sum(dim=(0, 1)).sqrt()
-        else:
-            acts[name] += 1.0 / nsamples * input.detach().pow(2).sum(dim=(0, 1)).sqrt()
-
-    hooks = []
-    for name, mod in model.named_modules():
-        if isinstance(mod, torch.nn.Linear) and "lm_head" not in name:
-            hooks.append(
-                mod.register_forward_pre_hook(functools.partial(save_acts, name=name))
-            )
-    device = next(model.parameters()).device
-    for batch in tqdm(data_loader):
-        batch = {k: v.to(device) for k, v in batch.items()}
-        model(**batch)
-        batch = None
-    torch.cuda.empty_cache()
-
-    for h in hooks:
-        h.remove()
+        self.remove_hooks()
 
-    return acts
+        return acts
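The score assembled in _infer_layer_sparsity above is |W[i, j]| * act_norm[j]: each weight's magnitude scaled by the collected activation norm of its input channel, the Wanda-style importance metric. A tiny worked example with made-up numbers:

import torch

weight = torch.tensor([[0.5, -2.0], [1.5, 0.1]])  # (out_features, in_features)
act_norm = torch.tensor([4.0, 0.25])              # per input channel, from the hooks

scores = weight.abs() * act_norm.unsqueeze(0)     # broadcasts over output rows
print(scores)
# tensor([[2.0000, 0.5000],
#         [6.0000, 0.0250]])
# The small weight on a high-activation channel scores 0.5 * 4.0 = 2.0 and
# outranks the larger weight on a quiet channel, 2.0 * 0.25 = 0.5.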

src/llmcompressor/modifiers/quantization/calibration.py

Lines changed: 28 additions & 42 deletions
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Any, Dict, Optional, Tuple
 
 import torch
 from compressed_tensors.quantization import QuantizationStatus, is_attention_module
@@ -146,71 +146,57 @@ def calibrate_activations(module: Module, value: torch.Tensor, base_name: str):
     )
 
 
-def calibrate_input_hook():
+def calibrate_input_hook(module: Module, args: Any):
     """
     Hook to calibrate input activations.
     Will call the observers to update the scales/zp before applying
     input QDQ in the module's forward pass.
     """
-
-    def hook_fn(module: Module, inp):
-        inp = inp[0] if isinstance(inp, tuple) else inp
-        calibrate_activations(module, value=inp, base_name="input")
-
-    return hook_fn
+    args = args[0] if isinstance(args, tuple) else args
+    calibrate_activations(module, value=args, base_name="input")
 
 
-def calibrate_output_hook():
+def calibrate_output_hook(module: Module, _args: Any, output: torch.Tensor):
    """
     Hook to calibrate output activations.
     Will call the observers to update the scales/zp before applying
     output QDQ.
     """
-
-    def hook_fn(module: Module, inp, output: torch.Tensor):
-        calibrate_activations(
-            module,
-            value=output,
-            base_name="output",
-        )
-        output = forward_quantize(
-            module=module,
-            value=output,
-            base_name="output",
-            args=module.quantization_scheme.output_activations,
-        )
-        return output
-
-    return hook_fn
+    calibrate_activations(
+        module,
+        value=output,
+        base_name="output",
+    )
+    output = forward_quantize(
+        module=module,
+        value=output,
+        base_name="output",
+        args=module.quantization_scheme.output_activations,
+    )
+    return output
 
 
-def calibrate_kv_cache_input_hook():
+def calibrate_kv_cache_input_hook(
+    module: Module, args: Any, kwargs: Dict[str, Any]
+) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
     """
     Hook to update inputs to attention layers when running
     kv_cache quantization. Will update the passed in
     kv_cache to singleton QuantizedKVParameterCache.
     """
-
-    def hook_fn(module: Module, args, kwargs):
-        kv_cache = getattr(module, "kv_cache")
-        kwargs["past_key_value"] = kv_cache
-        kwargs["use_cache"] = False
-        return args, kwargs
-
-    return hook_fn
+    kv_cache = getattr(module, "kv_cache")
+    kwargs["past_key_value"] = kv_cache
+    kwargs["use_cache"] = False
+    return args, kwargs
 
 
-def calibrate_kv_cache_output_hook():
+def calibrate_kv_cache_output_hook(module: Module, _args: Any, _output: torch.Tensor):
     """
     Hook to update k_scale and v_scale parameters when running kv_cache quantization.
     """
-
-    def hook_fn(module: Module, inpt, output: torch.Tensor):
-        kv_cache = getattr(module, "kv_cache")
-        update_parameter_data(module, kv_cache.k_scales[module.layer_idx], "k_scale")
-        update_parameter_data(module, kv_cache.v_scales[module.layer_idx], "v_scale")
-
-    return hook_fn
+    kv_cache = getattr(module, "kv_cache")
+    update_parameter_data(module, kv_cache.k_scales[module.layer_idx], "k_scale")
+    update_parameter_data(module, kv_cache.v_scales[module.layer_idx], "v_scale")
 
 
 def set_unset_kv_cache(module: Module):
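After this refactor the calibration hooks are plain functions with standard PyTorch hook signatures instead of closure factories, so they can be handed to a hook-registration API directly (presumably via HooksMixin.register_hook in the QuantizationModifier, which is not shown in this excerpt). The sketch below only shows how the new signatures line up with torch's native registration calls; the Linear module is a stand-in and the hooks are registered but never exercised:

import torch

from llmcompressor.modifiers.quantization.calibration import (
    calibrate_input_hook,
    calibrate_kv_cache_input_hook,
    calibrate_kv_cache_output_hook,
    calibrate_output_hook,
)

module = torch.nn.Linear(8, 8)  # stand-in; real targets are quantized/attention modules

# (module, args)                         -> forward pre-hook
h1 = module.register_forward_pre_hook(calibrate_input_hook)
# (module, args, output)                 -> forward hook
h2 = module.register_forward_hook(calibrate_output_hook)
# (module, args, kwargs) -> (args, kwargs): needs with_kwargs=True in raw PyTorch
h3 = module.register_forward_pre_hook(calibrate_kv_cache_input_hook, with_kwargs=True)
# (module, args, output)                 -> forward hook
h4 = module.register_forward_hook(calibrate_kv_cache_output_hook)

for handle in (h1, h2, h3, h4):
    handle.remove()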
