
Commit 490b987

kylesayrs authored and brian-dellabetta committed
implement quip
Signed-off-by: Kyle Sayers <[email protected]>
1 parent a4abb3d commit 490b987

3 files changed, +135 -0 lines changed
Lines changed: 1 addition & 0 deletions

@@ -1,3 +1,4 @@
 # flake8: noqa
 
 from .spinquant import SpinQuantModifier
+from .quip import QuIPModifier
Lines changed: 3 additions & 0 deletions

@@ -0,0 +1,3 @@
+# flake8: noqa
+
+from .base import *
Lines changed: 131 additions & 0 deletions

@@ -0,0 +1,131 @@
+from typing import Iterable, List, Literal, Optional, Union
+
+from compressed_tensors.transform import (
+    TransformArgs,
+    TransformConfig,
+    TransformScheme,
+    apply_transform_config,
+)
+from pydantic import Field, ValidationInfo, field_validator
+
+from llmcompressor.core import Event, EventType, State
+from llmcompressor.modifiers import Modifier
+
+__all__ = ["QuIPModifier"]
+
+
+class QuIPModifier(Modifier):
+    """
+    Implements the transforms according to
+    [QuIP#: Even Better LLM Quantization with Hadamard Incoherence and Lattice Codebooks](https://arxiv.org/pdf/2402.04396) # noqa: E501
+    [QuIP: 2-Bit Quantization of Large Language Models With Guarantees](https://arxiv.org/abs/2307.13304) # noqa: E501
+
+    Transforms (rotations) are extra layers added to a model which reduce the accuracy
+    loss induced by quantization. This is achieved through "rotating" weights and
+    activations into a space with a smaller dynamic range of values, thus decreasing
+    the range of scales required for quantization.
+
+    QuIP and QuIP# apply transforms to every linear layer, two of which are fused into
+    the model weights and two of which remain as online rotations computed at runtime.
+
+    :param transform_type: The type of transform to apply to the model.
+        `"hadamard"` has the least performance cost but only supports sizes which are
+        powers of two.
+        `"random-hadamard"` has more performance cost, but supports a much larger set
+        of sizes.
+        `"random-matrix"` has the greatest performance cost, but supports any size.
+    :param randomize: If true, create distinct transforms for each application
+    :param learnable: If true, attach gradients to transform weights for training
+    :param ignore: Modules to ignore when attaching transforms
+    :param transform_config: Optional transform config for overriding provided arguments
+    """
+
+    transform_type: Literal["hadamard", "random-hadamard", "random-matrix"] = Field(
+        default="hadamard", exclude=True
+    )
+    randomize: bool = Field(default=False, exclude=True)
+    learnable: bool = Field(default=False, exclude=True)
+    ignore: Union[str, List[str]] = Field(default="lm_head", exclude=True)
+
+    # optional override for more fine-grained control
+    # also included in recipe serialization
+    transform_config: Optional[TransformConfig] = Field(default=None, repr=False)
+
+    @field_validator("randomize", "learnable", mode="before")
+    def validate_not_implemented(cls, value, info: ValidationInfo):
+        raise NotImplementedError(f"{info.field_name} is not supported right now")
+
+    def on_initialize(self, state: State, **kwargs) -> bool:
+        if self.transform_config is not None:
+            return True
+
+        self.transform_config = self._create_config()
+        return True
+
+    def on_start(self, state: State, event: Event, **kwargs):
+        self.started_ = True
+
+        apply_transform_config(state.model, self.transform_config)
+
+    def on_event(self, state: State, event: Event, **kwargs):
+        if event.type_ == EventType.CALIBRATION_EPOCH_START:
+            if not self.started_:
+                self.on_start(state, None)
+
+        elif event.type_ == EventType.SEQUENTIAL_EPOCH_END:
+            pass
+
+        elif event.type_ == EventType.CALIBRATION_EPOCH_END:
+            if not self.ended_:
+                self.on_end(state, None)
+
+    def on_end(self, state: State, event: Event, **kwargs):
+        self.ended_ = True
+
+    def on_finalize(self, state: State, **kwargs) -> bool:
+        if not self.ended_:
+            self.on_end(state, None)
+
+        return True
+
+    def _create_config(self) -> TransformConfig:
+        return TransformConfig(
+            config_groups={
+                "v": TransformScheme(
+                    type=self.transform_type,
+                    apply=[
+                        TransformArgs(
+                            targets=["Linear"],
+                            location="input",  # non-mergable
+                            ignore=self.ignore,
+                        ),
+                        TransformArgs(
+                            targets=["Linear"],
+                            location="weight_input",
+                            inverse=True,
+                            ignore=self.ignore,
+                        ),
+                    ],
+                    randomize=self.randomize,
+                    requires_grad=self.learnable,
+                ),
+                "u": TransformScheme(
+                    type=self.transform_type,
+                    apply=[
+                        TransformArgs(
+                            targets=["Linear"],
+                            location="weight_output",
+                            ignore=self.ignore,
+                        ),
+                        TransformArgs(
+                            targets=["Linear"],
+                            location="output",  # non-mergable
+                            inverse=True,
+                            ignore=self.ignore,
+                        ),
+                    ],
+                    randomize=self.randomize,
+                    requires_grad=self.learnable,
+                ),
+            }
+        )
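
The class docstring above notes that QuIP attaches four rotations per Linear layer: two fused into the weights ("weight_input" / "weight_output") and two kept as online transforms ("input" / "output"). The sketch below is a small numerical check, not part of this commit, of why those pairs preserve the layer's output: random orthogonal matrices stand in for the Hadamard transforms, and the multiplication orientation is one illustrative convention rather than compressed-tensors' exact internals.

import torch

torch.manual_seed(0)
d_in, d_out, batch = 16, 16, 4

# random orthogonal matrices standing in for the "v" and "u" transforms
V, _ = torch.linalg.qr(torch.randn(d_in, d_in))
U, _ = torch.linalg.qr(torch.randn(d_out, d_out))

W = torch.randn(d_out, d_in)   # original Linear weight, y = x @ W.T
x = torch.randn(batch, d_in)   # a batch of activations

# "mergable" side: rotations absorbed into the weight offline
W_fused = U @ W @ V

# "non-mergable" side: rotate inputs before the layer, un-rotate outputs after
y_rotated = ((x @ V) @ W_fused.T) @ U

# the composition matches the original layer up to floating-point error,
# while W_fused (not W) is what would actually be quantized
assert torch.allclose(y_rotated, x @ W.T, atol=1e-5)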

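For context on how the new modifier might be driven end to end, here is a hedged usage sketch rather than anything from this commit: it assumes llmcompressor's oneshot entrypoint and QuantizationModifier, and it assumes QuIPModifier is re-exported from llmcompressor.modifiers.transform (the __init__ edited in the first file of this diff); the model path is a placeholder.

# Hedged sketch: assumes llmcompressor's oneshot entrypoint and
# QuantizationModifier; the QuIPModifier import path and model path are
# placeholders/assumptions, not taken from this commit.
from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.modifiers.transform import QuIPModifier

recipe = [
    # attach the fused + online QuIP rotations first
    QuIPModifier(transform_type="hadamard"),
    # then quantize the rotated Linear weights
    QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]),
]

oneshot(model="path/to/model", recipe=recipe)

Note that randomize and learnable are rejected by the field validator in this commit, so the sketch leaves them at their defaults.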