
Commit 32de48f

works for vision
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 29bf737 commit 32de48f

File tree: 1 file changed (+8, -6)

examples/multimodal_vision/gemma3_example.py
@@ -1,7 +1,7 @@
 import requests
-import torch
 from PIL import Image
 from transformers import AutoProcessor, Gemma3ForConditionalGeneration
+from transformers.data import DataCollatorWithPadding

 from llmcompressor import oneshot
 from llmcompressor.modifiers.quantization import GPTQModifier
@@ -11,6 +11,7 @@
 model_id = "google/gemma-3-4b-it"
 model = Gemma3ForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto")
 processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
+collator = DataCollatorWithPadding(processor.tokenizer)

 # Oneshot arguments
 DATASET_ID = "flickr30k"
@@ -19,10 +20,10 @@
 MAX_SEQUENCE_LENGTH = 2048


-# Define a oneshot data collator for multimodal inputs.
-def data_collator(batch):
-    assert len(batch) == 1
-    return {key: torch.tensor(value) for key, value in batch[0].items()}
+def data_collator(features: list[dict[str, object]]):
+    # remove extra dim added by vision processor
+    features = [{key: feature[key][0] for key in feature} for feature in features]
+    return collator(features)


 # Recipe
@@ -48,7 +49,8 @@ def data_collator(batch):
     max_seq_length=MAX_SEQUENCE_LENGTH,
     num_calibration_samples=NUM_CALIBRATION_SAMPLES,
     trust_remote_code_model=True,
-    # data_collator=data_collator,
+    batch_size=4,
+    data_collator=data_collator,
 )

 # Confirm generations of the quantized model look sane.
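Context on the change: the old collator asserted batch_size == 1 and tensorized a single sample, while the new one delegates padding to DataCollatorWithPadding so oneshot can calibrate with batch_size=4. Below is a minimal sketch (not part of the commit) of the same strip-then-pad pattern, using a hypothetical two-sample feature list; it assumes access to the google/gemma-3-4b-it tokenizer, though any Hugging Face tokenizer behaves the same here.

from transformers import AutoTokenizer
from transformers.data import DataCollatorWithPadding

# Assumes the checkpoint is accessible; any HF tokenizer works for this demo.
tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-4b-it")
collator = DataCollatorWithPadding(tokenizer)

# Hypothetical features shaped as a vision processor emits them:
# each value carries a leading batch dim of 1.
features = [
    {"input_ids": [[2, 10, 11, 12]], "attention_mask": [[1, 1, 1, 1]]},
    {"input_ids": [[2, 10]], "attention_mask": [[1, 1]]},
]

# Same pattern as the commit's data_collator: strip the extra dim, then pad.
features = [{key: feature[key][0] for key in feature} for feature in features]
batch = collator(features)
print(batch["input_ids"].shape)  # torch.Size([2, 4]): padded to the longest sample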
