[Transform] Online Rotations #1651
@@ -0,0 +1,60 @@
# python3 compress_model.py --model_id meta-llama/Llama-3.2-1B-Instruct --transform_type random-hadamard
import argparse
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.modifiers.transform import SpinQuantModifier
from llmcompressor.utils import dispatch_for_generation


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_id", type=str, help="Model stub to compress")
    parser.add_argument("--transform_type", type=str, default=None, help="Type of transform used in SpinQuantModifier")
    parser.add_argument("--scheme", type=str, default=None, help="Quantization scheme (e.g. W4A16)")
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()

    # Select model and load it.
    MODEL_ID = args.model_id
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

    # Select number of samples. 512 samples is a good place to start.
    # Increasing the number of samples can improve accuracy.
    NUM_CALIBRATION_SAMPLES = 512
    MAX_SEQUENCE_LENGTH = 2048

    # Configure the quantization algorithm to run.
    recipe = []
    if args.transform_type:
        recipe.append(SpinQuantModifier(rotations=["R1", "R2"], transform_type=args.transform_type))

    if args.scheme:
        recipe.append(QuantizationModifier(targets="Linear", scheme=args.scheme, ignore=["lm_head"]))

    # Apply algorithms.
    oneshot(
        model=model,
        recipe=recipe,
        dataset="ultrachat_200k",
        splits={"calibration": f"train_sft[:{NUM_CALIBRATION_SAMPLES}]"},
        max_seq_length=MAX_SEQUENCE_LENGTH,
        num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    )

    # Confirm generations of the quantized model look sane.
    print("\n\n")
    print("========== SAMPLE GENERATION ==============")
    dispatch_for_generation(model)
    input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
    output = model.generate(input_ids, max_new_tokens=100)
    print(tokenizer.decode(output[0]))
    print("==========================================\n\n")

    # Save to disk compressed.
    SAVE_DIR = MODEL_ID.split("/")[1] + f"-{args.transform_type}-{args.scheme}"
Review comment: The current method of constructing SAVE_DIR is fragile: MODEL_ID.split("/")[1] raises an IndexError when the model ID contains no "/" (for example, a local path). Suggested change:
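The body of the suggestion was not preserved on this page; the following is a sketch of the kind of fix being requested (a reconstruction, not the reviewer's actual patch):

    # take the last path component so model IDs without a "/" (e.g. local paths) still work
    SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + f"-{args.transform_type}-{args.scheme}"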
    model.save_pretrained(SAVE_DIR, save_compressed=True)
    tokenizer.save_pretrained(SAVE_DIR)
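For reference, the header comment shows the intended invocation; with quantization enabled it would look like the following (the scheme value here is just an example):

    python3 compress_model.py --model_id meta-llama/Llama-3.2-1B-Instruct --transform_type random-hadamard --scheme W4A16

Both flags are optional: when a flag is omitted, the corresponding modifier is simply left out of the recipe.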
@@ -0,0 +1,91 @@
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.modifiers.transform import SpinQuantModifier
from llmcompressor.utils import calibration_forward_context, dispatch_for_generation

# Select model and load it.
MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype="auto",
    attn_implementation="eager",
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Select calibration dataset.
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"

# Select number of samples. 512 samples is a good place to start.
# Increasing the number of samples can improve accuracy.
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048

# Load dataset and preprocess.
ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
ds = ds.shuffle(seed=42)


def preprocess(example):
    return {
        "text": tokenizer.apply_chat_template(
            example["messages"],
            tokenize=False,
        )
    }


ds = ds.map(preprocess)


# Tokenize inputs.
def tokenize(sample):
    return tokenizer(
        sample["text"],
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
        add_special_tokens=False,
    )


ds = ds.map(tokenize, remove_columns=ds.column_names)

# Configure the quantization algorithm to run.
# * apply SpinQuant transforms to the model in order to make quantization easier
# * optionally quantize the weights to 4 bit with a group size of 128
#   (the QuantizationModifier below is commented out)
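# Background (our summary of the SpinQuant paper, not part of this diff):
# R1/R2 are rotations that can be fused into the model weights ahead of
# time, while R3/R4 must be applied online at inference time -- the runtime
# support this PR's "online rotations" title refers to.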
recipe = [
    SpinQuantModifier(
        rotations=["R1", "R2", "R3", "R4"], transform_type="random-hadamard"
    ),
    # QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]),
]

# Apply algorithms.
oneshot(
    model=model,
    recipe=recipe,
    dataset=ds,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)

# Confirm generations of the quantized model look sane.
print("\n\n")
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
with calibration_forward_context(model):
    input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
    output = model.generate(input_ids, max_new_tokens=100)
    print(tokenizer.decode(output[0]))
print("==========================================\n\n")

# Save to disk compressed.
SAVE_DIR = MODEL_ID.split("/")[1] + "-transformed-w4a16"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
@@ -1,3 +1,4 @@
# flake8: noqa

from .fuse import *
from .prepare import *
@@ -0,0 +1,58 @@
from typing import Iterable

import torch
from compressed_tensors import (
    align_module_device,
    get_execution_device,
    update_offload_parameter,
)
from transformers.models.llama.modeling_llama import LlamaRMSNorm

__all__ = ["normalize_embedding", "fuse_norm_linears"]


PRECISION = torch.float64


def normalize_embedding(embedding: torch.nn.Module):
    if isinstance(embedding, torch.nn.Embedding):
        with align_module_device(embedding):
            weight_dtype = embedding.weight.dtype
            weight = embedding.weight.to(PRECISION)
            new_weight = weight - weight.mean(dim=-1, keepdim=True)
            new_weight = new_weight.to(weight_dtype)

        update_offload_parameter(embedding, "weight", new_weight)

    else:
        raise ValueError(f"Cannot normalize embedding of type {type(embedding)}")


def fuse_norm_linears(norm: torch.nn.Module, linears: Iterable[torch.nn.Linear]):
    """
    Fuse a norm layer into subsequent linear layers. This is useful for ensuring
    transform invariance between norm and linear layers.

    Note that a model cannot be properly trained after its norms have been fused

    :param norm: norm layer whose weight will be fused into subsequent linears
    :param linears: linear layers which directly follow the norm layer
    """
    if isinstance(norm, (torch.nn.RMSNorm, LlamaRMSNorm)):
        for linear in linears:
            # NOTE: spinquant does this op in float64
            exec_device = get_execution_device(norm)
            with align_module_device(norm, exec_device), align_module_device(
                linear, exec_device
            ):
                weight_dtype = linear.weight.dtype
                new_weight = linear.weight.to(PRECISION) * norm.weight.to(PRECISION)
                new_weight = new_weight.to(weight_dtype)

            update_offload_parameter(linear, "weight", new_weight)

        new_norm_weight = torch.ones_like(norm.weight, device="cpu")
        update_offload_parameter(norm, "weight", new_norm_weight)

    else:
        raise ValueError(f"Cannot fuse norm of type {type(norm)}")
@@ -0,0 +1,3 @@
# flake8: noqa

from .spinquant import SpinQuantModifier
@@ -0,0 +1,3 @@
# flake8: noqa

from .base import *
Review comment: The --model_id argument is essential for this script to run. Without it, args.model_id will be None, causing AutoModelForCausalLM.from_pretrained to fail. Please make this argument required to prevent runtime errors and add a more descriptive help message.
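A minimal sketch of the requested change (the help text wording here is our own):

    parser.add_argument(
        "--model_id",
        type=str,
        required=True,  # fail fast instead of passing None to from_pretrained
        help="Model stub or local path to compress, e.g. meta-llama/Llama-3.2-1B-Instruct",
    )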