
Commit 6415131

kylesayrs, brian-dellabetta, and dsikka authored
[Transform] QuIP Modifier (#1648)
## Purpose ##
* Enable quip-style transforms

## Prerequisites ##
* neuralmagic/compressed-tensors#370
* neuralmagic/compressed-tensors#412
* neuralmagic/compressed-tensors#414

## Changes ##
* Added `quip_example.py` to the examples folder
  * As made clear in the disclaimer, this example requires minimum versions of compressed-tensors and transformers to run
* Added `QuIPModifier`, which handles the construction of a quip-style transform config

## Testing ##
* Added modifier serialization and correctness tests

## Evaluation ##
Evaluation performed by @brian-dellabetta

Evals on Llama 3.2 1B with QuIP (num_fewshot 8, limit 1000, to be comparable with the results [here](https://github.com/vllm-project/llm-compressor/pull/1243/files#diff-bdc27f23c0dc2da352d5c83abdc0f267873edf4d36f88474038b975df75bd8c3R38-R64)):

| Strategy | gsm8k,strict | gsm8k_llama,strict |
|-|-|-|
| FP16 | .352 | .323 |
| QuIP | .348 | .322 |
| W4A16 | .180 | .017 |
| QuIP+W4A16 | .213 | .141 |

## Follow Ups ##
* Infer the data-free pipeline, even if a transform modifier is included
* Modify the example to use GPTQ once basic evaluation has been performed

---------

Signed-off-by: Kyle Sayers <[email protected]>
Signed-off-by: Brian Dellabetta <[email protected]>
Co-authored-by: Brian Dellabetta <[email protected]>
Co-authored-by: Dipika Sikka <[email protected]>
1 parent dbc4bc5 commit 6415131

6 files changed: +209 -7 lines changed

examples/transform/quip_example.py

Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
"""
NOTE: Models produced by this example will not be runnable in vLLM without
the following changes: https://github.com/vllm-project/vllm/pull/22486
"""

from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.modifiers.transform import QuIPModifier
from llmcompressor.utils import dispatch_for_generation

# Select model and load it.
# NOTE: because the datafree pipeline is being used in this
# example, you can use additional GPUs to support larger models
MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Configure the quantization algorithm to run.
# * apply QuIP transforms to the model in order to make quantization easier
# * quantize the weights to 4 bit with a group size 128
recipe = [
    QuIPModifier(transform_type="random-hadamard"),
    QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]),
]

# Apply algorithms.
oneshot(model=model, recipe=recipe, pipeline="datafree")

# Confirm generations of the quantized model look sane.
print("\n\n")
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
output = model.generate(input_ids, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")

# Save to disk compressed.
SAVE_DIR = MODEL_ID.split("/")[1] + "-quip-w4a16"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
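As a hedged usage sketch (not part of this commit), the saved directory can then be reloaded like any compressed-tensors checkpoint; the directory name follows from SAVE_DIR above, and whether the load decompresses correctly depends on the minimum compressed-tensors/transformers versions mentioned in the example's disclaimer:

from transformers import AutoModelForCausalLM, AutoTokenizer

# Hypothetical reload of the checkpoint written by the example above;
# assumes transformers/compressed-tensors versions that understand
# compressed checkpoints carrying transform configs.
model = AutoModelForCausalLM.from_pretrained("Llama-3.1-8B-Instruct-quip-w4a16")
tokenizer = AutoTokenizer.from_pretrained("Llama-3.1-8B-Instruct-quip-w4a16")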
src/llmcompressor/modifiers/transform/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1,3 +1,4 @@
 # flake8: noqa
 
+from .quip import QuIPModifier
 from .spinquant import SpinQuantModifier
src/llmcompressor/modifiers/transform/quip/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
# flake8: noqa

from .base import *
src/llmcompressor/modifiers/transform/quip/base.py

Lines changed: 155 additions & 0 deletions
@@ -0,0 +1,155 @@
from typing import List, Literal, Optional, Union

import torch
from compressed_tensors.transform import (
    TransformArgs,
    TransformConfig,
    TransformScheme,
    apply_transform_config,
)
from compressed_tensors.utils import TorchDtype
from pydantic import Field, ValidationInfo, field_validator

from llmcompressor.core import Event, EventType, State
from llmcompressor.modifiers import Modifier

__all__ = ["QuIPModifier"]


class QuIPModifier(Modifier):
    """
    Implements the transforms according to
    [QuIP#: Even Better LLM Quantization with Hadamard Incoherence and Lattice Codebooks](https://arxiv.org/pdf/2402.04396)
    [QuIP: 2-Bit Quantization of Large Language Models With Guarantees](https://arxiv.org/abs/2307.13304)

    Transforms (rotations) are extra layers added to a model which reduce the accuracy
    loss induced by quantization. This is achieved through "rotating" weights and
    activations into a space with a smaller dynamic range of values, thus decreasing
    the range of scales required for quantization.

    QuIP and QuIP# apply transforms to every linear layer, two of which are fused into
    the model weights and two of which remain as online rotations computed at runtime.

    Lifecycle:
        - on_initialize
            - as needed, create transform schemes for the "v" (input-side) and
              "u" (output-side) transforms
        - on_start
            - apply TransformConfig
                - fuse transforms into weights for mergeable transforms
                - add hooks for online transforms
        - on sequential epoch end
        - on_end
        - on_finalize

    :param transform_type: The type of transform to apply to the model.
        `"hadamard"` has the least performance cost but only supports sizes which are
        powers of two.
        `"random-hadamard"` has more performance cost, but supports a much larger set
        of sizes.
        `"random-matrix"` has the greatest performance cost, but supports any size
    :param targets: Modules to target when attaching transforms
    :param randomize: If true, create distinct transforms for each application
    :param learnable: If true, attach gradients to transform weights for training
    :param precision: Precision at which all transforms should be applied. This applies
        to both weight fusing and online rotations
    :param ignore: Modules to ignore when attaching transforms
    :param transform_config: Optional transform config for overriding provided arguments
    """  # noqa: E501

    transform_type: Literal["hadamard", "random-hadamard", "random-matrix"] = Field(
        default="random-hadamard"
    )
    targets: Union[List[str], str] = Field(default="Linear")
    randomize: bool = Field(default=False)
    learnable: bool = Field(default=False)
    precision: TorchDtype = Field(default=torch.float64)
    ignore: Union[str, List[str]] = Field(default="lm_head")

    # optional override for more fine-grained control
    # also included in recipe serialization
    transform_config: Optional[TransformConfig] = Field(default=None, repr=False)

    @field_validator("randomize", "learnable", mode="before")
    def validate_not_implemented(cls, value, info: ValidationInfo):
        if value:
            raise NotImplementedError(f"{info.field_name} is not supported right now")
        return value

    def on_initialize(self, state: State, **kwargs) -> bool:
        if self.transform_config is not None:
            return True

        self.transform_config = self._create_config()
        return True

    def on_start(self, state: State, event: Event, **kwargs):
        self.started_ = True

        apply_transform_config(state.model, self.transform_config)

    def on_event(self, state: State, event: Event, **kwargs):
        if event.type_ == EventType.CALIBRATION_EPOCH_START:
            if not self.started_:
                self.on_start(state, None)

        elif event.type_ == EventType.SEQUENTIAL_EPOCH_END:
            pass

        elif event.type_ == EventType.CALIBRATION_EPOCH_END:
            if not self.ended_:
                self.on_end(state, None)

    def on_end(self, state: State, event: Event, **kwargs):
        self.ended_ = True

    def on_finalize(self, state: State, **kwargs) -> bool:
        if not self.ended_:
            self.on_end(state, None)

        return True

    def _create_config(self) -> TransformConfig:
        return TransformConfig(
            config_groups={
                "v": TransformScheme(
                    type=self.transform_type,
                    apply=[
                        TransformArgs(
                            targets=self.targets,
                            location="input",  # non-mergable
                            ignore=self.ignore,
                        ),
                        TransformArgs(
                            targets=self.targets,
                            location="weight_input",
                            inverse=True,
                            ignore=self.ignore,
                        ),
                    ],
                    randomize=self.randomize,
                    requires_grad=self.learnable,
                    precision=self.precision,
                ),
                "u": TransformScheme(
                    type=self.transform_type,
                    apply=[
                        TransformArgs(
                            targets=self.targets,
                            location="weight_output",
                            ignore=self.ignore,
                        ),
                        TransformArgs(
                            targets=self.targets,
                            location="output",  # non-mergable
                            inverse=True,
                            ignore=self.ignore,
                        ),
                    ],
                    randomize=self.randomize,
                    requires_grad=self.learnable,
                    precision=self.precision,
                ),
            }
        )
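For intuition, here is a minimal sketch (not part of this commit) of why the paired TransformArgs above preserve the model's function: the "v" scheme applies a rotation at a Linear layer's input and fuses the inverse rotation into the weight along its input dimension, so the two cancel exactly while reducing the weight's dynamic range. A random orthogonal matrix stands in for the Hadamard-based transform below; the "u" scheme works the same way on the output side.

import torch

torch.manual_seed(0)
d = 64

# Stand-in for the rotation; for an orthogonal V, the inverse is V.T
V, _ = torch.linalg.qr(torch.randn(d, d, dtype=torch.float64))

W = torch.randn(8, d, dtype=torch.float64)  # Linear weight (out_features, in_features)
x = torch.randn(d, dtype=torch.float64)     # an activation vector

y_ref = W @ x              # original layer output
W_fused = W @ V.T          # location="weight_input" with inverse=True
y_rot = W_fused @ (V @ x)  # location="input", applied online at runtime

assert torch.allclose(y_ref, y_rot)  # rotations cancel; layer output is unchanged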

tests/llmcompressor/modifiers/transform/test_correctness.py

Lines changed: 5 additions & 5 deletions
@@ -5,7 +5,7 @@
 from transformers import AutoModelForCausalLM
 
 from llmcompressor.core import State
-from llmcompressor.modifiers.transform import SpinQuantModifier
+from llmcompressor.modifiers.transform import QuIPModifier, SpinQuantModifier
 from llmcompressor.transformers.sparsification.compressed_tensors_utils import (
     untie_word_embeddings,
 )
@@ -20,10 +20,10 @@
 @pytest.mark.parametrize(
     "modifier,model_dtype,precision,exp_mse",
     [
-        # (QuIPModifier, torch.bfloat16, torch.bfloat16, 5e-3),  # 0.0019
-        # (QuIPModifier, torch.bfloat16, torch.float32, 5e-3),  # 0.0022
-        # (QuIPModifier, torch.float32, torch.float32, 5e-10),  # 1.0e-10
-        # (QuIPModifier, torch.float32, torch.float64, 5e-11),  # 2.7e-11
+        (QuIPModifier, torch.bfloat16, torch.bfloat16, 5e-3),  # 0.0019
+        (QuIPModifier, torch.bfloat16, torch.float32, 5e-3),  # 0.0022
+        (QuIPModifier, torch.float32, torch.float32, 5e-10),  # 1.0e-10
+        (QuIPModifier, torch.float32, torch.float64, 5e-11),  # 2.7e-11
         (SpinQuantModifier, torch.bfloat16, torch.bfloat16, 5e-3),  # 0.0030
         (SpinQuantModifier, torch.bfloat16, torch.float32, 5e-3),  # 0.0029
         (SpinQuantModifier, torch.float32, torch.float32, 5e-4),  # 4e-4

tests/llmcompressor/modifiers/transform/test_serialization.py

Lines changed: 2 additions & 2 deletions
@@ -1,9 +1,9 @@
 import pytest
 
-from llmcompressor.modifiers.transform import SpinQuantModifier
+from llmcompressor.modifiers.transform import QuIPModifier, SpinQuantModifier
 
 
-@pytest.mark.parametrize("modifier", [SpinQuantModifier])
+@pytest.mark.parametrize("modifier", [SpinQuantModifier, QuIPModifier])
 def test_reload(modifier):
     instance = modifier(transform_type="hadamard")
     dump = instance.model_dump()
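The hunk ends where the diff view truncates. As a hedged sketch of the round trip the test exercises (assuming the pydantic v2 API implied by the Field/field_validator usage in base.py), reloading looks roughly like:

from llmcompressor.modifiers.transform import QuIPModifier

instance = QuIPModifier(transform_type="hadamard")
dump = instance.model_dump()                  # serialize to a plain dict
reloaded = QuIPModifier.model_validate(dump)  # rebuild the modifier from the dict
assert reloaded.transform_type == "hadamard"  # fields survive the round trip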
