-
Notifications
You must be signed in to change notification settings - Fork 453
Expand file tree
/
Copy pathllama3_8b_2of4.py
More file actions
113 lines (91 loc) · 3.22 KB
/
llama3_8b_2of4.py
File metadata and controls
113 lines (91 loc) · 3.22 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import argparse
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from llmcompressor import oneshot
from llmcompressor.modifiers.obcq import SparseGPTModifier
from llmcompressor.modifiers.quantization import QuantizationModifier
# Configuration
# Model identifier handed to the Hugging Face `from_pretrained` loaders.
MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"

# Calibration dataset identifier and split handed to `load_dataset`.
DATASET_ID, DATASET_SPLIT = "HuggingFaceH4/ultrachat_200k", "train_sft"

# Calibration budget: how many samples are drawn, and the token cap each
# sample is truncated to during tokenization.
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2_048
def parse_args(argv=None):
    """Parse command-line arguments for the compression script.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default), ``argparse`` falls back to ``sys.argv[1:]``, so calling
            ``parse_args()`` with no arguments behaves exactly as before.

    Returns:
        An ``argparse.Namespace`` with a boolean ``fp8`` attribute that is
        ``True`` only when ``--fp8`` was passed.
    """
    parser = argparse.ArgumentParser(description="Apply compression to a model")
    parser.add_argument("--fp8", action="store_true", help="Enable FP8 compression")
    return parser.parse_args(argv)
def preprocess(example):
    """Render an example's ``messages`` through the chat template as plain text.

    Relies on the module-level ``tokenizer``. The template is applied without
    tokenizing, and the rendered string is returned under a single ``"text"``
    key so it can be consumed by ``Dataset.map``.
    """
    rendered = tokenizer.apply_chat_template(example["messages"], tokenize=False)
    return {"text": rendered}
def tokenize(sample):
    """Tokenize the ``"text"`` field produced by ``preprocess``.

    Uses the module-level ``tokenizer``. Output is unpadded and truncated to
    ``MAX_SEQUENCE_LENGTH`` tokens. No special tokens are added here, which
    presumably relies on the chat template already emitting them (TODO
    confirm for the chosen tokenizer).
    """
    tokenizer_options = dict(
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
        add_special_tokens=False,
    )
    return tokenizer(sample["text"], **tokenizer_options)
def get_recipe(fp8_enabled, model_id=None):
    """Build the compression recipe and the output directory name.

    The recipe always starts with a SparseGPT pass at 50% sparsity using a
    2:4 mask structure. When ``fp8_enabled`` is true, an FP8_DYNAMIC
    ``QuantizationModifier`` (all ``Linear`` layers except ``lm_head``) is
    appended after it.

    Args:
        fp8_enabled: Whether to append the FP8 quantization modifier.
        model_id: Model identifier used to derive the save directory name.
            Defaults to the module-level ``MODEL_ID``.

    Returns:
        A ``(recipe, save_dir)`` tuple. ``recipe`` is the list of modifiers and
        ``save_dir`` is ``"<model-name>-2of4-sparse"`` or
        ``"<model-name>-2of4-W8A8-FP8-Dynamic-Per-Token"``.

    Raises:
        ValueError: If the quantization scheme is a non-string config whose
            weights are asymmetric. vLLM does not support asymmetric
            quantization combined with 2:4 sparsity.
    """
    if model_id is None:
        model_id = MODEL_ID
    # Take the last path component so both "org/name" hub ids and local
    # paths (or bare names without "/") yield a sensible directory name.
    model_name = model_id.rstrip("/").split("/")[-1]

    recipe = [
        SparseGPTModifier(
            sparsity=0.5,
            mask_structure="2:4",
            sequential_update=True,
            targets=[r"re:model.layers.\d*$"],
        )
    ]
    save_dir = f"{model_name}-2of4-sparse"

    if fp8_enabled:
        quant_modifier = QuantizationModifier(
            targets=["Linear"],
            ignore=["lm_head"],
            scheme="FP8_DYNAMIC",
        )
        # check that asymmetric quantization is not being used; a preset
        # scheme name (str) is accepted as-is.
        q_scheme = quant_modifier.scheme
        if not isinstance(q_scheme, str) and not q_scheme["weights"].symmetric:
            raise ValueError(
                "Asymmetric quantization with 2of4 sparsity is not supported by vLLM. "
                "Please use symmetric quantization"
            )
        recipe.append(quant_modifier)
        save_dir = f"{model_name}-2of4-W8A8-FP8-Dynamic-Per-Token"

    return recipe, save_dir
# Parse arguments
args = parse_args()

# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, device_map="auto", torch_dtype="auto"
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Load and preprocess dataset; the fixed shuffle seed keeps the calibration
# subset reproducible across runs.
ds = (
    load_dataset(DATASET_ID, split=DATASET_SPLIT)
    .shuffle(seed=42)
    .select(range(NUM_CALIBRATION_SAMPLES))
)
ds = ds.map(preprocess)
# Drop every pre-existing column so only the tokenizer outputs remain.
ds = ds.map(tokenize, remove_columns=ds.column_names)

# Get compression recipe and save directory
recipe, save_dir = get_recipe(args.fp8)

# Apply compression
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)

# Validate the compressed model
print("\n========== SAMPLE GENERATION ==============")
# Move inputs to wherever the model actually lives instead of assuming
# "cuda": with device_map="auto" the model may be placed on CPU or on a
# different accelerator.
input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to(
    model.device
)
output = model.generate(input_ids, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n")

# Save compressed model and tokenizer
model.save_pretrained(save_dir)
tokenizer.save_pretrained(save_dir)