@@ -1,7 +1,7 @@
 import requests
-import torch
 from PIL import Image
 from transformers import AutoProcessor, Gemma3ForConditionalGeneration
+from transformers.data import DataCollatorWithPadding
 
 from llmcompressor import oneshot
 from llmcompressor.modifiers.quantization import GPTQModifier
@@ -11,6 +11,7 @@
 model_id = "google/gemma-3-4b-it"
 model = Gemma3ForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto")
 processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
+collator = DataCollatorWithPadding(processor.tokenizer)
 
 # Oneshot arguments
 DATASET_ID = "flickr30k"
@@ -19,10 +20,10 @@
 MAX_SEQUENCE_LENGTH = 2048
 
 
-# Define a oneshot data collator for multimodal inputs.
-def data_collator(batch):
-    assert len(batch) == 1
-    return {key: torch.tensor(value) for key, value in batch[0].items()}
+def data_collator(features: list[dict[str, object]]):
+    # remove extra dim added by vision processor
+    features = [{key: feature[key][0] for key in feature} for feature in features]
+    return collator(features)
 
 
 # Recipe
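For context on this change: the old collator asserted a batch size of 1 and tensorized a single sample, while the new one strips the leading batch dimension of 1 that vision processors attach to each sample's tensors, then delegates padding to DataCollatorWithPadding. A minimal sketch of that behavior, using a stand-in tokenizer (bert-base-uncased is illustrative only; the commit uses processor.tokenizer) and only the input_ids key:

    import torch
    from transformers import AutoTokenizer, DataCollatorWithPadding

    # Stand-in tokenizer for illustration; the commit uses processor.tokenizer.
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    collate = DataCollatorWithPadding(tokenizer)

    # Vision processors typically return per-sample tensors with a leading
    # batch dim of 1, e.g. input_ids of shape (1, seq_len).
    features = [
        {"input_ids": torch.tensor([[101, 2023, 102]])},
        {"input_ids": torch.tensor([[101, 102]])},
    ]

    # Strip the extra dim, then pad the batch to its longest sequence.
    features = [{key: feature[key][0] for key in feature} for feature in features]
    batch = collate(features)
    print(batch["input_ids"].shape)  # torch.Size([2, 3])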
@@ -48,7 +49,8 @@ def data_collator(batch):
     max_seq_length=MAX_SEQUENCE_LENGTH,
     num_calibration_samples=NUM_CALIBRATION_SAMPLES,
     trust_remote_code_model=True,
-    # data_collator=data_collator,
+    batch_size=4,
+    data_collator=data_collator,
 )
 
 # Confirm generations of the quantized model look sane.
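The generation check itself sits outside this hunk. A sketch of the kind of smoke test the comment refers to, reusing the model and processor defined above; the image URL and prompt are illustrative, not taken from the commit:

    image_url = "http://images.cocodataset.org/train2017/000000231895.jpg"
    raw_image = Image.open(requests.get(image_url, stream=True).raw)

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": "Please describe the animal in this image."},
            ],
        }
    ]
    prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(images=raw_image, text=prompt, return_tensors="pt").to(model.device)
    output = model.generate(**inputs, max_new_tokens=100)
    print(processor.decode(output[0], skip_special_tokens=True))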