Commit ac7dbcd

kylesayrs authored and brian-dellabetta committed
add example, cleanup
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 490b987 commit ac7dbcd

File tree: 3 files changed, +90 -3 lines


examples/transform/quip_example.py

Lines changed: 87 additions & 0 deletions
@@ -0,0 +1,87 @@
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.modifiers.transform import QuIPModifier
from llmcompressor.utils import dispatch_for_generation

# Select model and load it.
MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype="auto",
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Select calibration dataset.
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"

# Select number of samples. 512 samples is a good place to start.
# Increasing the number of samples can improve accuracy.
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048

# Load dataset and preprocess.
ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
ds = ds.shuffle(seed=42)


def preprocess(example):
    return {
        "text": tokenizer.apply_chat_template(
            example["messages"],
            tokenize=False,
        )
    }


ds = ds.map(preprocess)


# Tokenize inputs.
def tokenize(sample):
    return tokenizer(
        sample["text"],
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
        add_special_tokens=False,
    )


ds = ds.map(tokenize, remove_columns=ds.column_names)

# Configure the quantization algorithm to run.
# * apply QuIP transforms to the model in order to make quantization easier
# * quantize the weights to 4 bit with a group size of 128
recipe = [
    QuIPModifier(transform_type="random-hadamard"),
    QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]),
]

# Apply algorithms.
oneshot(
    model=model,
    recipe=recipe,
    dataset=ds,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    pipeline="basic",
)

# Confirm generations of the quantized model look sane.
print("\n\n")
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
output = model.generate(input_ids, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")

# Save to disk compressed.
SAVE_DIR = MODEL_ID.split("/")[1] + "-transformed-w4a16"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
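
A minimal usage sketch, not part of this commit: one way to reload the compressed checkpoint produced by the example above, assuming vLLM with compressed-tensors support is installed and that the directory name matches the SAVE_DIR computed in the example.

# Illustrative only (not committed code): load and run the saved W4A16 checkpoint with vLLM.
# Assumes the example above was run and produced "Meta-Llama-3-8B-Instruct-transformed-w4a16".
from vllm import LLM, SamplingParams

llm = LLM(model="Meta-Llama-3-8B-Instruct-transformed-w4a16")
outputs = llm.generate(["Hello my name is"], SamplingParams(max_tokens=64))
print(outputs[0].outputs[0].text)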

src/llmcompressor/modifiers/transform/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
 # flake8: noqa
 
+from .quip import QuIPModifier
 from .spinquant import SpinQuantModifier
-from .quip import QuIPModifier

src/llmcompressor/modifiers/transform/quip/base.py

Lines changed: 2 additions & 2 deletions
@@ -1,4 +1,4 @@
-from typing import Iterable, List, Literal, Optional, Union
+from typing import List, Literal, Optional, Union
 
 from compressed_tensors.transform import (
     TransformArgs,
@@ -24,7 +24,7 @@ class QuIPModifier(Modifier):
     loss induced by quantization. This is achived through "rotating" weights and
     activations into a space with a smaller dynamic range of values, thus decreasing
     the range of scales required for quantization.
-
+
     QuIP and QuIP# apply transforms to every linear layer, two of which are fused into
     the model weights and two of which remain as online rotations computed at runtime.
 
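
As an aside on the docstring context above, a small sketch (not from this commit; QuIP's actual transforms are Hadamard-based and handled by compressed-tensors) of why fusing a rotation into the weights offline while applying its counterpart to activations online preserves the layer output:

# Illustrative only: an orthogonal rotation Q fused into a Linear weight offline,
# with the matching rotation applied to activations online, leaves outputs unchanged
# while reshaping the value distribution that quantization has to cover.
import torch

d_in, d_out = 8, 16
W = torch.randn(d_out, d_in)   # Linear weight (out_features x in_features)
x = torch.randn(1, d_in)       # an example activation

# Any orthogonal Q demonstrates the identity; QuIP uses (random) Hadamard rotations.
Q, _ = torch.linalg.qr(torch.randn(d_in, d_in))

W_rot = W @ Q                  # fused into the checkpoint ahead of time
x_rot = x @ Q                  # computed as an online rotation at runtime

# x Q (W Q)^T = x Q Q^T W^T = x W^T, since Q Q^T = I
assert torch.allclose(x @ W.T, x_rot @ W_rot.T, atol=1e-5)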
