Skip to content

Commit 3207124

Browse files
TransformModifier with SpinQuant R1&R2
Signed-off-by: Brian Dellabetta <[email protected]>
1 parent d1eb2a1 commit 3207124

File tree

6 files changed

+84
-25
lines changed

6 files changed

+84
-25
lines changed

examples/transform/llama3_example.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
from datasets import load_dataset
22
from transformers import AutoModelForCausalLM, AutoTokenizer
33

4-
from llmcompressor.modifiers.quantization import GPTQModifier
5-
from llmcompressor.modifiers.transform import TransformModifier
64
from llmcompressor import oneshot
5+
from llmcompressor.modifiers.quantization import GPTQModifier, QuantizationModifier
6+
from llmcompressor.modifiers.transform import TransformModifier
7+
from llmcompressor.utils import dispatch_for_generation
78

89
# Select model and load it.
910
MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
@@ -56,8 +57,8 @@ def tokenize(sample):
5657
# Configure the quantization algorithm to run.
5758
# * quantize the weights to 4 bit with GPTQ with a group size 128
5859
recipe = [
59-
TransformModifier(),
60-
GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]),
60+
TransformModifier(preset_config="LLAMA_SPINQUANT_R1R2"),
61+
QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]),
6162
]
6263

6364
# Apply algorithms.
@@ -70,15 +71,16 @@ def tokenize(sample):
7071
num_calibration_samples=NUM_CALIBRATION_SAMPLES,
7172
)
7273

73-
# Confirm generations of the quantized model look sane.
74-
print("\n\n")
75-
print("========== SAMPLE GENERATION ==============")
76-
input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
77-
output = model.generate(input_ids, max_new_tokens=100)
78-
print(tokenizer.decode(output[0]))
79-
print("==========================================\n\n")
74+
# # Confirm generations of the quantized model look sane.
75+
# print("\n\n")
76+
# print("========== SAMPLE GENERATION ==============")
77+
# dispatch_for_generation(model)
78+
# input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
79+
# output = model.generate(input_ids, max_new_tokens=100)
80+
# print(tokenizer.decode(output[0]))
81+
# print("==========================================\n\n")
8082

8183
# Save to disk compressed.
82-
SAVE_DIR = MODEL_ID.split("/")[1] + "-W4A16-G128"
84+
SAVE_DIR = MODEL_ID.split("/")[1] + "-transform-quant-w4a16"
8385
model.save_pretrained(SAVE_DIR, save_compressed=True)
8486
tokenizer.save_pretrained(SAVE_DIR)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
# flake8: noqa
22

33
from .transform import TransformModifier
4+
from .transform.presets import TRANSFORM_PRESETS
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from .quip import QUIP
2+
from .spinquant import LLAMA_SPINQUANT, LLAMA_SPINQUANT_R1R2
3+
4+
TRANSFORM_PRESETS = {
5+
"QUIP": QUIP,
6+
"LLAMA_SPINQUANT": LLAMA_SPINQUANT,
7+
"LLAMA_SPINQUANT_R1R2": LLAMA_SPINQUANT_R1R2,
8+
}

src/llmcompressor/modifiers/transform/template/spinquant.py renamed to src/llmcompressor/modifiers/transform/presets/spinquant.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
from compressed_tensors.transform import TransformArgs, TransformConfig, TransformScheme
22

3+
# Ref: https://arxiv.org/pdf/2405.16406 Fig 1
4+
5+
# All rotations
36
LLAMA_SPINQUANT = TransformConfig(
47
transform_groups={
58
"R1": TransformScheme(
@@ -62,3 +65,43 @@
6265
),
6366
}
6467
)
68+
69+
70+
# Mergeable rotations R1 and R2 only
71+
LLAMA_SPINQUANT_R1R2 = TransformConfig(
72+
config_groups={
73+
"R1": TransformScheme(
74+
type="hadamard",
75+
apply=[
76+
TransformArgs(
77+
targets=["embed_tokens", "o_proj", "down_proj"],
78+
location="weight_output",
79+
),
80+
TransformArgs(
81+
targets=[
82+
"q_proj",
83+
"k_proj",
84+
"v_proj",
85+
"up_proj",
86+
"gate_proj",
87+
"lm_head",
88+
],
89+
location="weight_input",
90+
inverse=True,
91+
),
92+
],
93+
),
94+
"R2": TransformScheme(
95+
type="hadamard",
96+
apply=[
97+
TransformArgs(
98+
targets=["v_proj"],
99+
location="weight_output",
100+
),
101+
TransformArgs(
102+
targets=["o_proj"], location="weight_input", inverse=True
103+
),
104+
],
105+
),
106+
}
107+
)

src/llmcompressor/modifiers/transform/transform.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,33 @@
1-
from typing import Dict, Optional
1+
from typing import Optional
22

3-
from compressed_tensors.transform import TransformScheme, apply_transform_config
3+
from compressed_tensors.transform import TransformConfig, apply_transform_config
4+
from pydantic import ValidationError, model_validator
45

56
from llmcompressor.core import Event, EventType, State
67
from llmcompressor.modifiers import Modifier
7-
8-
from .template.quip import QUIP
8+
from llmcompressor.modifiers.transform.presets import TRANSFORM_PRESETS
99

1010

1111
class TransformModifier(Modifier):
1212
preset_config: Optional[str] = None
13-
config_groups: Optional[Dict[str, TransformScheme]] = None
13+
config: Optional[TransformConfig] = None
1414

1515
# validate that at least one of preset_config or config is provided,
16+
@model_validator(mode="after")
17+
def validate_model_after(model: "TransformModifier") -> "TransformModifier":
18+
if model.preset_config is None and model.config is None:
19+
raise ValidationError("Either a config or a preset_config must be provided")
20+
21+
if model.preset_config is not None:
22+
if model.preset_config not in TRANSFORM_PRESETS:
23+
raise ValidationError(
24+
f"Invalid preset_config '{model.preset_config}' "
25+
f"must be in {TRANSFORM_PRESETS.keys()}"
26+
)
27+
model.config = TRANSFORM_PRESETS[model.preset_config]
1628

1729
def on_initialize(self, state: State, **kwargs) -> bool:
18-
if self.preset_config is not None:
19-
# import config template and customize to model
20-
pass
21-
22-
# config = TransformConfig(config_groups=self.config_groups)
23-
config = QUIP
24-
25-
apply_transform_config(state.model, config)
30+
apply_transform_config(state.model, self.config)
2631

2732
return True
2833

0 commit comments

Comments
 (0)