+import torch
 from fastdup.sentry import fastdup_capture_exception
 from fastdup.definitions import MISSING_LABEL
 from fastdup.galleries import fastdup_imread
 import cv2
+from tqdm import tqdm

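+# cache one captioning pipeline per device so repeated generate_labels() calls reuse the loaded model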
+device_to_captioner = {}
+
+def init_captioning(model_name='automatic', device='cpu', batch_size=8, max_new_tokens=20,
+                    use_float_16=True):

-def generate_labels(filenames, model_name='automatic', device='cpu', batch_size=8):
     '''
     This function generates captions for a given set of images, and takes the following arguments:
         - filenames: the list of images passed to the function
@@ -15,64 +20,82 @@ def generate_labels(filenames, model_name='automatic', device='cpu', batch_size=8):
             - BLIP: 'blip'
         - batch_size: the size of image batches to caption (default: 8)
         - device: whether to use a GPU (default: 'cpu'; set to 'gpu' for GPU)
+        - max_new_tokens: the maximum number of tokens generated per caption (default: 20)
     '''
+
+    global device_to_captioner
     # use GPU if device is specified
     if device == 'gpu':
         device = 0
     elif device == 'cpu':
         device = -1
+        use_float_16 = False  # float16 is only useful on GPU, so force full precision on CPU
     else:
-        assert False, "Incompatible device name entered. Available device names are gpu and cpu."
+        assert False, f"Incompatible device name entered: {device}. Available device names are gpu and cpu."

     # confirm necessary dependencies are installed, and import them
     try:
         from transformers import pipeline
         from transformers.utils import logging
-        logging.set_verbosity_info()
-        import torch
-        from PIL import Image
-        from tqdm import tqdm
+        logging.set_verbosity(50)  # 50 == CRITICAL: only critical transformers messages are logged
+
     except Exception as e:
         fastdup_capture_exception("Auto generate labels", e)
         print("Auto captioning requires an installation of the following libraries:\n")
-        print("   huggingface transformers\n   pytorch\n   pillow\n   tqdm\n")
-        print("to install, use `pip install transformers torch pillow tqdm`")
-        return [MISSING_LABEL] * len(filenames)
+        print("   huggingface transformers\n   pytorch\n")
+        print("    to install, use `pip3 install transformers torch`")
+        raise

     # dictionary of captioning models
     models = {
         'automatic': "nlpconnect/vit-gpt2-image-captioning",
         'vitgpt2': "nlpconnect/vit-gpt2-image-captioning",
-        'blip2': "Salesforce/blip2-opt-2.7b",
+        'blip-2': "Salesforce/blip2-opt-2.7b",
         'blip': "Salesforce/blip-image-captioning-large"
     }
-
+    assert model_name in models.keys(), f"Unknown captioning model {model_name}, allowed models are {list(models.keys())}"
     model = models[model_name]
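+    # build a HuggingFace image-to-text pipeline; fall back to CPU when no GPU is visible,
+    # and load float16 weights only when half precision was requested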
+    has_gpu = torch.cuda.is_available()
+    captioner = pipeline("image-to-text", model=model, device=device if has_gpu else "cpu", max_new_tokens=max_new_tokens,
+                         torch_dtype=torch.float16 if use_float_16 else torch.float32)
+    device_to_captioner[device] = captioner

-    # generate captions
-    try:
-        captioner = pipeline("image-to-text", model=model, device=device)
-
-        captions = []
-
-        for pred in captioner(filenames, batch_size=batch_size):
-            #caption = pred['generated_text']
-            caption = ''.join([d['generated_text'] for d in pred])
-            captions.append(caption)
+    return captioner

+def generate_labels(filenames, model_name='automatic', device='cpu', batch_size=8, max_new_tokens=20, use_float_16=True):
+    global device_to_captioner
+    if device not in device_to_captioner:
+        captioner = init_captioning(model_name, device, batch_size, max_new_tokens, use_float_16)
+    else:
+        captioner = device_to_captioner[device]

-        '''for image_path in tqdm(filenames):
-            img = Image.open(image_path)
-            pred = captioner(img)
-            caption = pred[0]['generated_text']
-            captions.append(caption)'''
-        return captions
-
+    captions = []
+    # generate captions
+    try:
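+        # caption the images in batch_size-sized chunks so one failing image only affects its own chunk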
+        for i in tqdm(range(0, len(filenames), batch_size)):
+            chunk = filenames[i:i + batch_size]
+            try:
+                for pred in captioner(chunk, batch_size=batch_size):
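+                    # each pred is a list of {'generated_text': ...} dicts; BLIP pieces are joined with a space separator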
+                    charstring = '' if model_name != 'blip' else ' '
+                    caption = charstring.join([d['generated_text'] for d in pred])
+                    # Split the sentence into words
+                    words = caption.split()
+                    # Filter out words containing '#'
+                    filtered_words = [word for word in words if '#' not in word]
+                    # Join the filtered words back into a sentence
+                    caption = ' '.join(filtered_words)
+                    caption = caption.strip()
+                    captions.append(caption)
+            except Exception as ex:
+                print("Failed to caption chunk", chunk[:5], ex)
+                captions.extend([MISSING_LABEL] * len(chunk))

     except Exception as e:
         fastdup_capture_exception("Auto caption image", e)
         return [MISSING_LABEL] * len(filenames)

+    return captions
+

 def generate_vqa_labels(filenames, text, kwargs):
     # confirm necessary dependencies are installed, and import them
@@ -156,3 +179,15 @@ def generate_age_labels(filenames, kwargs):
         fastdup_capture_exception("Age label", e)
         return [MISSING_LABEL] * len(filenames)

+if __name__ == "__main__":
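+    # quick manual smoke test against a local, developer-specific folder of two images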
+    import fastdup
+    from fastdup.captions import generate_labels
+    file = "/Users/dannybickson/visual_database/cxx/unittests/two_images/"
+    import os
+    files = os.listdir(file)
+    files = [os.path.join(file, f) for f in files]
+    ret = generate_labels(files, model_name='blip')
+    assert len(ret) == 2
+    print(ret)
+    for r in ret:
+        assert "shelf" in r or "shelves" in r or "store" in r