-
Notifications
You must be signed in to change notification settings - Fork 453
Expand file tree
/
Copy pathmistral3_example.py
More file actions
89 lines (76 loc) · 2.72 KB
/
mistral3_example.py
File metadata and controls
89 lines (76 loc) · 2.72 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import json
import os
import requests
import torch
from PIL import Image
from transformers import AutoProcessor, Mistral3ForConditionalGeneration
from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import GPTQModifier
# Load the full-precision model and its multimodal processor.
model_id = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
model = Mistral3ForConditionalGeneration.from_pretrained(
    model_id, device_map="auto", torch_dtype="auto"
)
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

# Use a custom calibration chat template, rather than the overly-verbose default.
# The JSON file is resolved relative to this script, so the example works
# regardless of the current working directory.
file_path = os.path.join(os.path.dirname(__file__), "mistral3_chat_template.json")
with open(file_path, "r", encoding="utf-8") as file:
    processor.chat_template = json.load(file)["chat_template"]

# Oneshot arguments
DATASET_ID = "flickr30k"
DATASET_SPLIT = "test"
# Used both to slice the calibration split and as oneshot's sample budget.
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048
# Define a oneshot data collator for multimodal inputs.
def data_collator(batch):
    """Collate a single preprocessed calibration sample into tensors.

    Every field of the sample is converted with ``torch.tensor``. The
    ``pixel_values`` field is additionally created with ``model.dtype`` (the
    module-level model loaded by this script), so image inputs match the
    model's dtype.

    Args:
        batch: sequence holding exactly one preprocessed sample, i.e. a
            mapping from field name to an array-like value.

    Returns:
        dict mapping each field name to a ``torch.Tensor``.

    Raises:
        ValueError: if ``batch`` does not contain exactly one sample.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(batch) != 1:
        raise ValueError(
            f"data_collator expects a batch of exactly 1 sample, got {len(batch)}"
        )
    sample = batch[0]
    return {
        key: (
            torch.tensor(value, dtype=model.dtype)
            if key == "pixel_values"
            else torch.tensor(value)
        )
        for key, value in sample.items()
    }
# Recipe: a single GPTQ pass with the "W4A16" scheme over Linear modules.
# `sequential_targets` names the decoder-layer class handed to GPTQ as its
# sequential unit. The regex patterns in `ignore` exclude the LM head, the
# vision tower and the multimodal projector from quantization.
recipe = [
    GPTQModifier(
        scheme="W4A16",
        targets="Linear",
        ignore=[
            "re:.*lm_head",
            "re:vision_tower.*",
            "re:multi_modal_projector.*",
        ],
        sequential_targets=["MistralDecoderLayer"],
    ),
]
# Perform oneshot calibration/quantization of `model` (the same object is used
# below for the sample generation and is then saved compressed).
#
# Hand oneshot the already-configured `processor` object rather than the model
# id: the custom calibration chat template is attached to that object.
# NOTE(review): passing `tokenizer=model_id` presumably makes oneshot load a
# fresh processor from the hub with the default (verbose) template, which
# silently bypasses the custom template during calibration preprocessing.
oneshot(
    model=model,
    processor=processor,
    dataset=DATASET_ID,
    splits={"calibration": f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]"},
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    trust_remote_code_model=True,
    data_collator=data_collator,
)
# Confirm generations of the quantized model look sane.
print("========== SAMPLE GENERATION ==============")
messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Please describe the animal in this image\n"},
            {"type": "image"},
        ],
    },
]
prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
image_url = "http://images.cocodataset.org/train2017/000000231895.jpg"
# Bound the download time and fail fast with a clear HTTP error instead of
# handing an error page to PIL.
response = requests.get(image_url, stream=True, timeout=30)
response.raise_for_status()
raw_image = Image.open(response.raw)

# Move inputs to the device reported by `model.device` instead of hard-coding
# "cuda": the model was loaded with device_map="auto", so its placement is
# decided at load time.
inputs = processor(images=raw_image, text=prompt, return_tensors="pt").to(
    model.device
)
inputs["pixel_values"] = inputs["pixel_values"].to(model.dtype)  # fix dtype
output = model.generate(**inputs, max_new_tokens=100)
print(processor.decode(output[0], skip_special_tokens=True))
print("==========================================")
# Save to disk compressed. The directory is named after the last component of
# `model_id`, so both hub ids ("org/name") and local paths ("/models/name/")
# yield "name-W4A16-G128" (indexing [1] breaks for ids without "/" and picks
# the wrong component for absolute paths).
SAVE_DIR = model_id.rstrip("/").split("/")[-1] + "-W4A16-G128"
model.save_pretrained(SAVE_DIR, save_compressed=True)
processor.save_pretrained(SAVE_DIR)