Skip to content

Commit 4d630df

Browse files
kelkelcheng, kylesayrs, dsikka, ved1beta, and rahul-tuli
authored
[Tracing] Support tracing of Gemma3 [#1248] (#1373)
SUMMARY: Add support for tracing of Gemma3: [issue#1248](#1248). Steps that I have done: 1. Create gemma3.py from HF and update __init__.py. 2. Classes and functions that I modified: 2.1 Gemma3ForConditionalGeneration: _update_causal_mask and forward 2.2 Gemma3TextModel: _update_causal_mask, forward, and _prepare_4d_causal_attention_mask_with_cache_position TEST PLAN: Ran: `llmcompressor.trace --model_id google/gemma-3-4b-it --model_class TraceableGemma3ForConditionalGeneration --ignore "lm_head" "re:vision_tower.*" --modality vision` Output: <img width="796" alt="trace_output" src="https://github.com/user-attachments/assets/8f5c9c7d-32a9-4b12-b4b2-10b6a4352846" /> This is my first attempt at solving this issue. It is a fun learning experience and please review it carefully. Gemma3 can go through tracing now, but we might need further tests for the quantization as well. --------- Signed-off-by: Kelvin Cheng <kelvincheng216@gmail.com> Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> Signed-off-by: Rahul Tuli <rtuli@redhat.com> Signed-off-by: Brian Dellabetta <bdellabe@redhat.com> Signed-off-by: Domenic Barbuzzi <dbarbuzz@redhat.com> Co-authored-by: Kyle Sayers <kylesayrs@gmail.com> Co-authored-by: Dipika Sikka <dipikasikka1@gmail.com> Co-authored-by: Vedant <146507396+ved1beta@users.noreply.github.com> Co-authored-by: Rahul Tuli <rtuli@redhat.com> Co-authored-by: Brian Dellabetta <brian-dellabetta@users.noreply.github.com> Co-authored-by: Domenic Barbuzzi <dbarbuzz@redhat.com>
1 parent 7bc1881 commit 4d630df

File tree

3 files changed

+559
-1
lines changed

3 files changed

+559
-1
lines changed
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
import requests
import torch
from PIL import Image
from transformers import AutoProcessor

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.transformers.tracing import TraceableGemma3ForConditionalGeneration

# Load model.
# The traceable wrapper class makes Gemma3's forward pass symbolically
# traceable so the model can be calibrated layer-by-layer during oneshot.
model_id = "google/gemma-3-4b-it"
model = TraceableGemma3ForConditionalGeneration.from_pretrained(
    model_id, device_map="auto", torch_dtype="auto"
)
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

# Oneshot arguments
DATASET_ID = "flickr30k"
# Calibrate on the first 512 samples of the "test" split.
DATASET_SPLIT = {"calibration": "test[:512]"}
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048
23+
24+
# Define a oneshot data collator for multimodal inputs.
def data_collator(batch):
    """Collate one multimodal calibration sample into a dict of tensors.

    The calibration dataloader must use batch size 1 (each sample's image
    features have their own shapes and cannot be stacked). Raise instead of
    ``assert``, which is silently stripped under ``python -O``.

    :param batch: list containing exactly one sample dict of array-likes
    :return: the sample with every value converted to a ``torch.Tensor``
    :raises ValueError: if the batch does not contain exactly one sample
    """
    if len(batch) != 1:
        raise ValueError(f"Expected batch size 1, got {len(batch)}")
    return {key: torch.tensor(value) for key, value in batch[0].items()}
28+
29+
30+
# Recipe: quantize all Linear layers to 4-bit weights / 16-bit activations.
# NOTE: the original ignore pattern "re:*.lm_head" is an invalid regular
# expression (a leading '*' has nothing to repeat), so the LM head would not
# be matched; "re:.*lm_head" is the intended pattern.
recipe = [
    GPTQModifier(
        targets="Linear",
        scheme="W4A16",
        # Exclude the LM head and all vision-tower / projector modules --
        # only the language model's Linear layers are quantized.
        ignore=["re:.*lm_head", "re:vision_tower.*", "re:multi_modal_projector.*"],
    ),
]
38+
39+
# Perform oneshot
# Runs calibration plus GPTQ compression in a single pass over the
# calibration split; no finetuning is performed.
oneshot(
    model=model,
    tokenizer=model_id,  # tokenizer/processor is re-loaded from the model id
    dataset=DATASET_ID,
    splits=DATASET_SPLIT,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    trust_remote_code_model=True,
    data_collator=data_collator,
)
51+
52+
# Confirm generations of the quantized model look sane.
print("========== SAMPLE GENERATION ==============")
messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Please describe the animal in this image\n"},
            {"type": "image"},
        ],
    },
]
prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
image_url = "http://images.cocodataset.org/train2017/000000231895.jpg"
raw_image = Image.open(requests.get(image_url, stream=True).raw)

# Move inputs to the model's own device instead of hard-coding "cuda":
# with device_map="auto" the embedding layers may be placed on any device
# (including CPU), and .to("cuda") would crash on CUDA-less machines.
inputs = processor(images=raw_image, text=prompt, return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=100)
print(processor.decode(output[0], skip_special_tokens=True))
print("==========================================")
71+
72+
# Save to disk compressed.
# e.g. "google/gemma-3-4b-it" -> "gemma-3-4b-it-W4A16-G128"
SAVE_DIR = model_id.split("/")[1] + "-W4A16-G128"
model.save_pretrained(SAVE_DIR, save_compressed=True)
processor.save_pretrained(SAVE_DIR)

src/llmcompressor/transformers/tracing/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
from .gemma3 import (
2+
Gemma3ForConditionalGeneration as TraceableGemma3ForConditionalGeneration,
3+
)
14
from .llava import (
25
LlavaForConditionalGeneration as TraceableLlavaForConditionalGeneration,
36
)
@@ -11,12 +14,13 @@
1114
Idefics3ForConditionalGeneration as TraceableIdefics3ForConditionalGeneration,
1215
)
1316
from .qwen2_5_vl import (
14-
Qwen2_5_VLForConditionalGeneration as TraceableQwen2_5_VLForConditionalGeneration
17+
Qwen2_5_VLForConditionalGeneration as TraceableQwen2_5_VLForConditionalGeneration,
1518
)
1619
from .debug import get_model_class
1720

1821
__all__ = [
1922
"get_model_class",
23+
"TraceableGemma3ForConditionalGeneration",
2024
"TraceableLlavaForConditionalGeneration",
2125
"TraceableMllamaForConditionalGeneration",
2226
"TraceableQwen2VLForConditionalGeneration",

0 commit comments

Comments
 (0)