-
 from fastdup.sentry import fastdup_capture_exception
 from fastdup.definitions import MISSING_LABEL
 from fastdup.galleries import fastdup_imread
-from tqdm import tqdm
 import cv2
 
-def generate_labels(filenames, kwargs):
-    try:
-        from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
-        import torch
-    except Exception as e:
-        fastdup_capture_exception("Auto generate labels", e)
-        print("For auto captioning images need to install transforms and torch packages using `pip install transformers torch`")
-        return [MISSING_LABEL] * len(filenames)
-
-    try:
-        from PIL import Image
-        model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
-        feature_extractor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
-        tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
-
-        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        model.to(device)
-        max_length = 16
-        num_beams = 4
-        gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
-
-        images = []
-        for image_path in tqdm(filenames):
-            i_image = fastdup_imread(image_path, None, kwargs=kwargs)
-            if i_image is not None:
-                i_image = cv2.cvtColor(i_image, cv2.COLOR_BGR2RGB)
-                im_pil = Image.fromarray(i_image)
-                images.append(im_pil)
-            else:
-                images.append(None)
-
-        pixel_values = feature_extractor(images=images, return_tensors="pt").pixel_values
-        pixel_values = pixel_values.to(device)
-        output_ids = model.generate(pixel_values, **gen_kwargs)
-
-        preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
-        preds = [pred.strip() for pred in preds]
-        return preds
-    except Exception as e:
-        fastdup_capture_exception("Auto caption image", e)
-        return [MISSING_LABEL] * len(filenames)
 
-def generate_blip_labels(filenames, kwargs):
+def generate_labels(filenames, modelname='automatic', batch_size=8):
+    '''
+    This function generates captions for a given set of images, and takes the following arguments:
+    - filenames: the list of images passed to the function
+    - modelname: the captioning model to be used (default: 'automatic', which selects ViT-GPT2)
+      currently available models are:
+        - ViT-GPT2: 'vitgpt2'
+        - BLIP-2: 'blip2'
+        - BLIP: 'blip'
+    - batch_size: the size of image batches to caption (default: 8)
+    '''
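+    # Example usage (a minimal sketch; the file paths are hypothetical):
+    #     captions = generate_labels(['img1.jpg', 'img2.jpg'], modelname='blip', batch_size=4)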
 
+    # confirm necessary dependencies are installed, and import them
     try:
-        from transformers import BlipProcessor, BlipForConditionalGeneration
+        from transformers import pipeline
+        import torch
         from PIL import Image
+        from tqdm import tqdm
     except Exception as e:
         fastdup_capture_exception("Auto generate labels", e)
-        print("For auto captioning images need to install transforms and torch packages using `pip install transformers`")
+        print("Auto captioning requires an installation of the following libraries:\n")
+        print("   huggingface transformers\n   pytorch\n   pillow\n   tqdm\n")
+        print("to install, use `pip install transformers torch pillow tqdm`")
         return [MISSING_LABEL] * len(filenames)
 
-    try:
-        processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
-        model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")
-        preds = []
-        for image_path in tqdm(filenames):
-            i_image = fastdup_imread(image_path, None, kwargs=kwargs)
-            if i_image is not None:
-                i_image = cv2.cvtColor(i_image, cv2.COLOR_BGR2RGB)
-                im_pil = Image.fromarray(i_image)
-                inputs = processor(im_pil, return_tensors="pt")
-                out = model.generate(**inputs)
-                preds.append(processor.decode(out[0], skip_special_tokens=True))
-            else:
-                preds.append(MISSING_LABEL)
-        return preds
+    # dictionary of captioning models
+    models = {
+        'automatic': "nlpconnect/vit-gpt2-image-captioning",
+        'vitgpt2': "nlpconnect/vit-gpt2-image-captioning",
+        'blip2': "Salesforce/blip2-opt-2.7b",
+        'blip': "Salesforce/blip-image-captioning-large"
+    }
 
-    except Exception as e:
-        fastdup_capture_exception("Auto caption image blip", e)
-        return [MISSING_LABEL] * len(filenames)
-
-def generate_blip2_labels(filenames, kwargs, text=None):
-
-    try:
-        from transformers import Blip2Processor, Blip2Model
-        from PIL import Image
-        import torch
-    except Exception as e:
-        fastdup_capture_exception("Auto generate labels", e)
-        print("For auto captioning images need to install transforms and torch packages using `pip install transformers torch`")
-        return [MISSING_LABEL] * len(filenames)
+    model = models[modelname]
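+    # note: an unrecognized modelname raises KeyError here, before the try/except below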
 
+    # generate captions
     try:
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        captioner = pipeline("image-to-text", model=model, device=device, batch_size=batch_size)
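+        # note: batch_size only batches inputs within a single pipeline call;
+        # the loop below invokes the captioner once per image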
 
-        processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
-        model = Blip2Model.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16)
-        device = "cuda" if torch.cuda.is_available() else "cpu"
-        model.to(device)
-        preds = []
+        captions = []
         for image_path in tqdm(filenames):
-            i_image = fastdup_imread(image_path, None, kwargs=kwargs)
-            if i_image is not None:
-                i_image = cv2.cvtColor(i_image, cv2.COLOR_BGR2RGB)
-                im_pil = Image.fromarray(i_image)
-                inputs = processor(images=im_pil, text=text, return_tensors="pt").to(device, torch.float16)
-                generated_ids = model.generate(**inputs)
-                generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
-                preds.append(generated_text)
-            else:
-                preds.append(MISSING_LABEL)
-        return preds
-
-    except Exception as e:
-        fastdup_capture_exception("Auto caption image blip", e)
-        return [MISSING_LABEL] * len(filenames)
-
-
+            img = Image.open(image_path)
+            pred = captioner(img)
+            caption = pred[0]['generated_text']
+            captions.append(caption)
+        return captions
 
 
+    except Exception as e:
+        fastdup_capture_exception("Auto caption image", e)
+        return [MISSING_LABEL] * len(filenames)
 
 
 def generate_vqa_labels(filenames, text, kwargs):
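+    # asks the same question ('text') about every image,
+    # e.g. text="how many people are in the photo?" (hypothetical example)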
+    # confirm necessary dependencies are installed, and import them
     try:
         from transformers import ViltProcessor, ViltForQuestionAnswering
+        import torch
         from PIL import Image
+        from tqdm import tqdm
     except Exception as e:
         fastdup_capture_exception("Auto generate labels", e)
-        print(
-            "For auto captioning images need to install transforms and torch packages using `pip install transformers`")
+        print("Auto captioning requires an installation of the following libraries:\n")
+        print("   huggingface transformers\n   pytorch\n   pillow\n   tqdm\n")
+        print("to install, use `pip install transformers torch pillow tqdm`")
         return [MISSING_LABEL] * len(filenames)
 
     try:
@@ -150,15 +95,26 @@ def generate_vqa_labels(filenames, text, kwargs):
 
     except Exception as e:
         fastdup_capture_exception("Auto caption image vqa", e)
-        return [MISSING_LABEL]*len(filenames)
+        return [MISSING_LABEL] * len(filenames)
 
 
 def generate_age_labels(filenames, kwargs):
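+    # classifies each image into an age bracket with the nateraw/vit-age-classifier model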
-    from transformers import ViTFeatureExtractor, ViTForImageClassification
-    model = ViTForImageClassification.from_pretrained('nateraw/vit-age-classifier')
-    transforms = ViTFeatureExtractor.from_pretrained('nateraw/vit-age-classifier')
+    # confirm necessary dependencies are installed, and import them
+    try:
+        from transformers import ViTFeatureExtractor, ViTForImageClassification
+        import torch
+        from PIL import Image
+        from tqdm import tqdm
+    except Exception as e:
+        fastdup_capture_exception("Auto generate labels", e)
+        print("Auto captioning requires an installation of the following libraries:\n")
+        print("   huggingface transformers\n   pytorch\n   pillow\n   tqdm\n")
+        print("to install, use `pip install transformers torch pillow tqdm`")
+        return [MISSING_LABEL] * len(filenames)
 
     try:
+        model = ViTForImageClassification.from_pretrained('nateraw/vit-age-classifier')
+        transforms = ViTFeatureExtractor.from_pretrained('nateraw/vit-age-classifier')
         preds = []
         # Get example image from official fairface repo + read it in as an image
         for image_path in tqdm(filenames):
@@ -174,8 +130,9 @@ def generate_age_labels(filenames, kwargs):
 
             # Predicted Classes
             pred = int(proba.argmax(1)[0].int())
-            preds.append( model.config.id2label[pred])
+            preds.append(model.config.id2label[pred])
         return preds
     except Exception as e:
         fastdup_capture_exception("Age label", e)
         return [MISSING_LABEL] * len(filenames)
+