Skip to content

Commit 5b9df94

Browse files
committed
add example
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 370c04c commit 5b9df94

File tree

1 file changed

+134
-0
lines changed

1 file changed

+134
-0
lines changed
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
from compressed_tensors.offload import dispatch_model
2+
from datasets import load_dataset
3+
from transformers import AutoProcessor, Qwen2AudioForConditionalGeneration
4+
5+
from llmcompressor import oneshot
6+
from llmcompressor.modifiers.quantization import GPTQModifier
7+
8+
# Select the model to quantize and load it together with its processor.
MODEL_ID = "Qwen/Qwen2-Audio-7B-Instruct"

model = Qwen2AudioForConditionalGeneration.from_pretrained(MODEL_ID, dtype="auto")
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
12+
13+
# Select calibration dataset.
DATASET_ID = "MLCommons/peoples_speech"
DATASET_SUBSET = "test"
DATASET_SPLIT = "test"

# Select number of samples (this example calibrates with 64).
# Increasing the number of samples can improve accuracy.
NUM_CALIBRATION_SAMPLES = 64
MAX_SEQUENCE_LENGTH = 2048

# Load one raw (unpreprocessed) sample for generation testing.
raw_ds = load_dataset(
    DATASET_ID,
    DATASET_SUBSET,
    split=f"{DATASET_SPLIT}[:1]",
)

# Load dataset for calibration.
ds = load_dataset(
    DATASET_ID,
    DATASET_SUBSET,
    split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]",
)
36+
37+
38+
def preprocess(example):
    """Turn one calibration sample into Qwen2Audio model inputs.

    Builds a three-turn chat conversation (audio, question, reference
    answer), renders it with the chat template, then runs the processor
    so the audio placeholder is expanded into real audio tokens.
    """
    audio_turn = {
        "role": "user",
        "content": [
            {"type": "audio", "audio_url": "placeholder"},
        ],
    }
    question_turn = {
        "role": "user",
        "content": [
            {"type": "text", "text": "What did the person say?"},
        ],
    }
    answer_turn = {
        "role": "assistant",
        "content": [
            {"type": "text", "text": example["text"]},
        ],
    }

    # Render the conversation to a single prompt string.
    prompt = processor.apply_chat_template(
        [audio_turn, question_turn, answer_turn],
        tokenize=False,
        add_generation_prompt=False,
    )

    # The processor handles audio token expansion.
    audio = example["audio"]
    batch = processor(
        text=prompt,
        audio=[audio["array"]],
        sampling_rate=audio["sampling_rate"],
        return_tensors="pt",
    )

    # The processor returns batched tensors; keep only the single sample.
    squeezed = {}
    for name, tensor in batch.items():
        squeezed[name] = tensor[0]
    return squeezed
76+
77+
78+
# Map every raw sample through `preprocess`, dropping the original
# columns so only model inputs remain.
ds = ds.map(preprocess, remove_columns=ds.column_names)

# Recipe: 4-bit weight / 16-bit activation GPTQ on every Linear layer,
# skipping the LM head and the audio tower.
recipe = GPTQModifier(
    scheme="W4A16",
    targets="Linear",
    ignore=["lm_head", "re:audio_tower.*"],
)

# Apply the one-shot quantization algorithms.
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)
95+
96+
# Confirm generations of the model after quantization.
print("========== SAMPLE GENERATION ==============")
# Dispatch the (now quantized) model onto available devices for inference.
dispatch_model(model)
raw_sample = raw_ds[0]
conversation = [
    {
        "role": "user",
        "content": [
            {"type": "audio", "audio_url": "placeholder"},
        ],
    },
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "What did the person say?"},
        ],
    },
]
text_prompt = processor.apply_chat_template(
    conversation, tokenize=False, add_generation_prompt=True
)
inputs = processor(
    text=text_prompt,
    audio=[raw_sample["audio"]["array"]],
    sampling_rate=raw_sample["audio"]["sampling_rate"],
    return_tensors="pt",
).to(model.device)

output = model.generate(**inputs, max_new_tokens=100)
print(processor.batch_decode(output, skip_special_tokens=True)[0])
print("==========================================\n\n")
# Expected transcription (approximate):
# that's where you have a lot of windows in the south no actually that's passive solar
# and passive solar is something that was developed and designed in the 1960s and 70s
# and it was a great thing for what it was at the time but it's not a passive house
130+
131+
# Save the compressed model and processor to disk.
model_name = MODEL_ID.rstrip("/").split("/")[-1]
SAVE_DIR = f"{model_name}-W4A16-G128"
model.save_pretrained(SAVE_DIR, save_compressed=True)
processor.save_pretrained(SAVE_DIR)

0 commit comments

Comments
 (0)