Commit 999d660

[Audio] Add whisper fp8 dynamic example (#1111)
## Purpose ##
* Add an example of quantizing a multimodal model with FP8 dynamic

## Changes ##
* Add whisper FP8 example, in collaboration with @mgoin

## Testing ##
* Ran the example

Signed-off-by: Kyle Sayers <[email protected]>
1 parent 7fc4a67 commit 999d660

File tree: 1 file changed, +46 −0
@@ -0,0 +1,46 @@
from datasets import load_dataset
from transformers import AutoProcessor, WhisperForConditionalGeneration

from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.transformers import oneshot

MODEL_ID = "openai/whisper-large-v2"

# Load model.
model = WhisperForConditionalGeneration.from_pretrained(
    MODEL_ID, device_map="auto", torch_dtype="auto"
)
model.config.forced_decoder_ids = None
processor = AutoProcessor.from_pretrained(MODEL_ID)
processor.tokenizer.set_prefix_tokens(language="en", task="transcribe")

# Configure the quantization algorithm and scheme.
# In this case, we:
#   * quantize the weights to fp8 with per-channel scales via PTQ
#   * quantize the activations to fp8 with dynamic per-token scales
recipe = QuantizationModifier(
    targets="Linear", scheme="FP8_DYNAMIC", ignore=["lm_head"]
)

# Apply quantization.
oneshot(model=model, recipe=recipe)

# Confirm generations of the quantized model look sane.
print("========== SAMPLE GENERATION ==============")
ds = load_dataset(
    "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation[:1]"
)
sample = ds[0]["audio"]
input_features = processor(
    sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt"
).input_features
input_features = input_features.to(model.device)
predicted_ids = model.generate(input_features, language="en", forced_decoder_ids=None)
print(processor.batch_decode(predicted_ids, skip_special_tokens=False)[0])
# Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel
print("==========================================")

# Save to disk in compressed-tensors format.
SAVE_DIR = MODEL_ID.split("/")[1] + "-FP8-Dynamic"
model.save_pretrained(SAVE_DIR, save_compressed=True)
processor.save_pretrained(SAVE_DIR)
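The recipe comments above describe FP8 weights with per-channel scales and FP8 activations with dynamic per-token scales. As a rough illustration of the per-token dynamic part only (a minimal sketch, not llm-compressor's internals; fp8_dynamic_per_token is a hypothetical helper, assuming a PyTorch build with float8 support), each token row is scaled at runtime by its own max magnitude before casting to fp8:

import torch

def fp8_dynamic_per_token(x: torch.Tensor):
    # x: [num_tokens, hidden_dim]; one scale per token row, computed at runtime.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / fp8_max
    x_fp8 = (x / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return x_fp8, scale  # dequantize as x_fp8.to(x.dtype) * scale

x = torch.randn(4, 16)
x_fp8, scale = fp8_dynamic_per_token(x)
print((x_fp8.to(torch.float32) * scale - x).abs().max())  # small round-trip error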

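After saving, the compressed checkpoint can typically be reloaded for inference (a hedged follow-up sketch, not part of the commit; assumes a transformers build with compressed-tensors support and that SAVE_DIR resolved to "whisper-large-v2-FP8-Dynamic"):

from transformers import AutoProcessor, WhisperForConditionalGeneration

# Reload the quantized checkpoint saved above (path assumed).
loaded_model = WhisperForConditionalGeneration.from_pretrained(
    "whisper-large-v2-FP8-Dynamic", device_map="auto", torch_dtype="auto"
)
loaded_processor = AutoProcessor.from_pretrained("whisper-large-v2-FP8-Dynamic")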