
Commit 0e4e002

use qkv hooks

Signed-off-by: Kyle Sayers <[email protected]>
1 parent 0a146f8

File tree: 4 files changed, +64 -14 lines

src/llmcompressor/modifiers/quantization/calibration.py

Lines changed: 26 additions & 11 deletions

@@ -1,4 +1,5 @@
-from typing import Any, Dict, Optional, Tuple
+from functools import partial
+from typing import TYPE_CHECKING, Any, Dict, Optional, Set, Tuple
 
 import torch
 from compressed_tensors.quantization import (
@@ -13,18 +14,24 @@
 from compressed_tensors.utils import align_module_device, update_parameter_data
 from loguru import logger
 from torch.nn import Module
+from torch.utils.hooks import RemovableHandle
 
 from llmcompressor.modifiers.quantization.cache import QuantizedKVParameterCache
 from llmcompressor.observers import Observer
 from llmcompressor.utils.helpers import getattr_chain
 
+if TYPE_CHECKING:
+    from llmcompressor.modifiers.utils.hooks import HooksMixin
+
+
 DEFAULT_MAXSHRINK = 0.20
 DEFAULT_PATIENCE = 5
 DEFAULT_AVERAGING_CONSTANT = 0.01
 DEFAULT_GRID = 100.0
 DEFAULT_NORM = 2.4
 
 __all__ = [
+    "register_calibrate_attn_hooks",
     "initialize_observer",
     "update_weight_zp_scale",
     "calibrate_input_hook",
@@ -205,14 +212,30 @@ def calibrate_activations(module: Module, value: torch.Tensor, base_name: str):
     )
 
 
-def calibrate_input_hook(module: Module, args: Any):
+def register_calibrate_attn_hooks(
+    modifier: "HooksMixin", attention_impl
+) -> Set[RemovableHandle]:
+    return {
+        modifier.register_hook(
+            attention_impl, partial(calibrate_input_hook, base_name="q"), "query"
+        ),
+        modifier.register_hook(
+            attention_impl, partial(calibrate_input_hook, base_name="k"), "key"
+        ),
+        modifier.register_hook(
+            attention_impl, partial(calibrate_input_hook, base_name="v"), "value"
+        ),
+    }
+
+
+def calibrate_input_hook(module: Module, args: Any, base_name: str = "input"):
     """
     Hook to calibrate input activations.
     Will call the observers to update the scales/zp before applying
     input QDQ in the module's forward pass.
     """
     args = args[0] if isinstance(args, tuple) else args
-    calibrate_activations(module, value=args, base_name="input")
+    calibrate_activations(module, value=args, base_name=base_name)
 
 
 def calibrate_output_hook(module: Module, _args: Any, output: torch.Tensor):
@@ -290,14 +313,6 @@ def initialize_attention_observers(module: Module):
     initialize_observer(module, "v", input_args)
 
 
-def calibrate_attention(
-    module: Module, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
-):
-    calibrate_activations(module, value=query, base_name="q")
-    calibrate_activations(module, value=key, base_name="k")
-    calibrate_activations(module, value=value, base_name="v")
-
-
 def apply_calibration_status(module: Module):
     scheme = getattr(module, "quantization_scheme", None)
     if not scheme:
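For illustration, the partial-bound hook pattern above can be mimicked with plain torch forward pre-hooks: a single hook function takes a base_name keyword, and functools.partial binds it to "q", "k", or "v" at registration time. This is a minimal sketch under those assumptions, not the HooksMixin machinery; observe, input_hook, and proj are hypothetical stand-ins for calibrate_activations, calibrate_input_hook, and the attention projections.

from functools import partial

import torch
from torch.nn import Linear, Module


def observe(module: Module, value: torch.Tensor, base_name: str):
    # Stand-in for calibrate_activations: record per-branch statistics.
    print(f"{base_name}: min={value.min().item():.3f} max={value.max().item():.3f}")


def input_hook(module: Module, args, base_name: str = "input"):
    # Same shape as calibrate_input_hook: unpack positional args, then observe.
    value = args[0] if isinstance(args, tuple) else args
    observe(module, value, base_name=base_name)


proj = {"q": Linear(8, 8), "k": Linear(8, 8), "v": Linear(8, 8)}
handles = {
    proj[name].register_forward_pre_hook(partial(input_hook, base_name=name))
    for name in ("q", "k", "v")
}

x = torch.randn(2, 8)
for layer in proj.values():
    layer(x)

for handle in handles:
    handle.remove()  # detach the calibration hooks when done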

src/llmcompressor/modifiers/quantization/quantization/mixin.py

Lines changed: 1 addition & 1 deletion

@@ -244,7 +244,7 @@ def _initialize_hooks(self, model: torch.nn.Module) -> Set[RemovableHandle]:
 
         # TODO: attnq
         # attention_impl = enable_compressed_attention(model)
-        # hooks.add(self.register_hook(attention_impl, calibrate_attention, "calib"))
+        # hooks |= register_calibrate_attn_hooks(self, attention_impl)
 
         for module in model.modules():
             if not hasattr(module, "quantization_scheme"):
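The commented-out line switches from accumulating one handle with .add() to merging a whole set of handles with |=, since register_calibrate_attn_hooks returns a Set[RemovableHandle]. A minimal sketch of that bookkeeping, where Handle is a hypothetical stand-in for torch's RemovableHandle:

from typing import Set


class Handle:
    # Hypothetical stand-in for torch.utils.hooks.RemovableHandle.
    def remove(self) -> None:
        pass


hooks: Set[Handle] = set()

# Old pattern: one handle per call, accumulated with .add().
hooks.add(Handle())

# New pattern: a helper returns a set of handles (q, k, v),
# so the accumulator merges it with |= instead.
qkv_handles = {Handle(), Handle(), Handle()}
hooks |= qkv_handles

assert len(hooks) == 4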

src/llmcompressor/modifiers/transform/spinquant/base.py

Lines changed: 35 additions & 2 deletions

@@ -215,7 +215,40 @@ def _create_r2_scheme(self, model: PreTrainedModel) -> TransformScheme:
         )
 
     def _create_r3_scheme(self) -> TransformScheme:
-        raise NotImplementedError()
+        return (
+            TransformScheme(
+                type=self.transform_type,
+                randomize=self.randomize,
+                requires_grad=self.learnable,
+                apply=[
+                    TransformArgs(
+                        targets=[self.mappings.attn],
+                        location="attn_q",
+                    ),
+                    TransformArgs(
+                        targets=[self.mappings.attn],
+                        location="attn_k",
+                    ),
+                ],
+            ),
+        )
 
     def _create_r4_scheme(self) -> TransformScheme:
-        raise NotImplementedError()
+        return (
+            TransformScheme(
+                type=self.transform_type,
+                randomize=self.randomize,
+                requires_grad=self.learnable,
+                apply=[
+                    TransformArgs(
+                        targets=[*self.mappings.mlp_out],
+                        location="input",
+                    ),
+                    TransformArgs(
+                        targets=[*self.mappings.mlp_out],
+                        location="weight_input",
+                        inverse=True,
+                    ),
+                ],
+            ),
+        )
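The R4 scheme pairs a transform on the mlp_out input with its inverse folded into the weight, which leaves the layer's full-precision output unchanged while reshaping the activation distribution seen by quantization. A minimal numerical sketch of that invariance, assuming an orthogonal rotation (the real scheme may use Hadamard or learned transforms; rot, W, and x here are illustrative, not llm-compressor internals):

import torch

torch.manual_seed(0)
hidden = 8

# Random orthogonal matrix standing in for the R4 rotation.
rot, _ = torch.linalg.qr(torch.randn(hidden, hidden))

W = torch.randn(4, hidden)   # down_proj-style weight
x = torch.randn(hidden)      # activation entering mlp_out

y_ref = W @ x                # original output
# location="input" applies rot to the activation; location="weight_input"
# with inverse=True folds rot^-1 (= rot.T for orthogonal rot) into the weight.
y_rot = (W @ rot.T) @ (rot @ x)

assert torch.allclose(y_ref, y_rot, atol=1e-4)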

src/llmcompressor/modifiers/transform/spinquant/mappings.py

Lines changed: 2 additions & 0 deletions

@@ -10,6 +10,7 @@
 class SpinQuantMapping(BaseModel):
     embedding: str
 
+    attn: str
     attn_q: str
     attn_k: str
     attn_v: str
@@ -31,6 +32,7 @@ def cast_to_list(cls, value):
 
 _default_mappings = SpinQuantMapping(
     embedding="re:.*embed_tokens$",
+    attn="re:.*self_attn$",
     attn_q="re:.*q_proj$",
     attn_k="re:.*k_proj$",
     attn_v="re:.*v_proj$",
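The new attn mapping targets whole self-attention modules rather than the individual q/k/v projections, which is what the attention-level hooks and attn_q/attn_k transform locations need. The "re:" prefix marks the target as a regular expression over module names; a simplified sketch of that matching (the matches helper is illustrative and omits llm-compressor's other target-resolution rules):

import re


def matches(target: str, module_name: str) -> bool:
    # Targets prefixed with "re:" are treated as regexes over module names.
    if target.startswith("re:"):
        return re.match(target[3:], module_name) is not None
    return target == module_name


names = [
    "model.layers.0.self_attn",
    "model.layers.0.self_attn.q_proj",
    "model.layers.0.mlp.down_proj",
]

print([n for n in names if matches("re:.*self_attn$", n)])
# -> ['model.layers.0.self_attn']
print([n for n in names if matches("re:.*q_proj$", n)])
# -> ['model.layers.0.self_attn.q_proj']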
