116116 """Your response will be used by a text-to-image model, so avoid useless meta phrases like “This image shows…”, "You are looking at...", etc.""" ,
117117]
118118
# Length options offered in the UI: descriptive buckets first, then explicit
# word counts 20, 30, ..., 260 (as strings, since they are combo-box choices).
CAPTION_LENGTH_CHOICES = ["any", "very short", "short", "medium-length", "long", "very long"] + [
    str(i) for i in range(20, 261, 10)
]
123122
124123
125124def build_prompt (caption_type : str , caption_length : str | int , extra_options : list [str ], name_input : str ) -> str :
@@ -130,47 +129,66 @@ def build_prompt(caption_type: str, caption_length: str | int, extra_options: li
130129 map_idx = 1 # numeric-word-count template
131130 else :
132131 map_idx = 2 # length descriptor template
133-
132+
134133 prompt = CAPTION_TYPE_MAP [caption_type ][map_idx ]
135134
136135 if extra_options :
137136 prompt += " " + " " .join (extra_options )
138-
137+
139138 return prompt .format (
140139 name = name_input or "{NAME}" ,
141140 length = caption_length ,
142141 word_count = caption_length ,
143142 )
144143
145144
146-
147145class JoyCaptionPredictor :
148146 def __init__ (self , model : str , memory_mode : str ):
149147 checkpoint_path = Path (folder_paths .models_dir ) / "LLavacheckpoints" / Path (model ).stem
150148 if not checkpoint_path .exists ():
151149 # Download the model
152150 from huggingface_hub import snapshot_download
153- snapshot_download (repo_id = model , local_dir = str (checkpoint_path ), force_download = False , local_files_only = False )
154-
151+
152+ snapshot_download (
153+ repo_id = model , local_dir = str (checkpoint_path ), force_download = False , local_files_only = False
154+ )
155+
155156 self .device = "cuda" if torch .cuda .is_available () else "cpu"
156157
157158 self .processor = AutoProcessor .from_pretrained (str (checkpoint_path ))
158159
159160 if memory_mode == "Default" :
160- self .model = LlavaForConditionalGeneration .from_pretrained (str (checkpoint_path ), torch_dtype = "bfloat16" , device_map = "auto" )
161+ self .model = LlavaForConditionalGeneration .from_pretrained (
162+ str (checkpoint_path ), torch_dtype = "bfloat16" , device_map = "auto"
163+ )
161164 else :
162165 from transformers import BitsAndBytesConfig
166+
163167 qnt_config = BitsAndBytesConfig (
164168 ** MEMORY_EFFICIENT_CONFIGS [memory_mode ],
165- llm_int8_skip_modules = ["vision_tower" , "multi_modal_projector" ], # Transformer's Siglip implementation has bugs when quantized, so skip those.
169+ llm_int8_skip_modules = [
170+ "vision_tower" ,
171+ "multi_modal_projector" ,
172+ ], # Transformer's Siglip implementation has bugs when quantized, so skip those.
173+ )
174+ self .model = LlavaForConditionalGeneration .from_pretrained (
175+ str (checkpoint_path ), torch_dtype = "auto" , device_map = "auto" , quantization_config = qnt_config
166176 )
167- self .model = LlavaForConditionalGeneration .from_pretrained (str (checkpoint_path ), torch_dtype = "auto" , device_map = "auto" , quantization_config = qnt_config )
168177 print (f"Loaded model { model } with memory mode { memory_mode } " )
169- #print(self.model)
178+ # print(self.model)
170179 self .model .eval ()
171-
180+
172181 @torch .inference_mode ()
173- def generate (self , image : Image .Image , system : str , prompt : str , max_new_tokens : int , temperature : float , top_p : float , top_k : int ) -> str :
182+ def generate (
183+ self ,
184+ image : Image .Image ,
185+ system : str ,
186+ prompt : str ,
187+ max_new_tokens : int ,
188+ temperature : float ,
189+ top_p : float ,
190+ top_k : int ,
191+ ) -> str :
174192 convo = [
175193 {
176194 "role" : "system" ,
@@ -183,12 +201,12 @@ def generate(self, image: Image.Image, system: str, prompt: str, max_new_tokens:
183201 ]
184202
185203 # Format the conversation
186- convo_string = self .processor .apply_chat_template (convo , tokenize = False , add_generation_prompt = True )
204+ convo_string = self .processor .apply_chat_template (convo , tokenize = False , add_generation_prompt = True )
187205 assert isinstance (convo_string , str )
188206
189207 # Process the inputs
190- inputs = self .processor (text = [convo_string ], images = [image ], return_tensors = "pt" ).to (' cuda' )
191- inputs [' pixel_values' ] = inputs [' pixel_values' ].to (torch .bfloat16 )
208+ inputs = self .processor (text = [convo_string ], images = [image ], return_tensors = "pt" ).to (" cuda" )
209+ inputs [" pixel_values" ] = inputs [" pixel_values" ].to (torch .bfloat16 )
192210
193211 # Generate the captions
194212 generate_ids = self .model .generate (
@@ -203,16 +221,19 @@ def generate(self, image: Image.Image, system: str, prompt: str, max_new_tokens:
203221 )[0 ]
204222
205223 # Trim off the prompt
206- generate_ids = generate_ids [inputs [' input_ids' ].shape [1 ]:]
224+ generate_ids = generate_ids [inputs [" input_ids" ].shape [1 ] :]
207225
208226 # Decode the caption
209- caption = self .processor .tokenizer .decode (generate_ids , skip_special_tokens = True , clean_up_tokenization_spaces = False )
227+ caption = self .processor .tokenizer .decode (
228+ generate_ids , skip_special_tokens = True , clean_up_tokenization_spaces = False
229+ )
210230 return caption .strip ()
211231
212232
213233class JoyCaption :
214234 @classmethod
215235 def INPUT_TYPES (cls ):
236+ # fmt: off
216237 req = {
217238 "image" : ("IMAGE" ,),
218239 "memory_mode" : (list (MEMORY_EFFICIENT_CONFIGS .keys ()),),
@@ -232,37 +253,54 @@ def INPUT_TYPES(cls):
232253 "top_p" : ("FLOAT" , {"default" : 0.9 , "min" : 0.0 , "max" : 1.0 , "step" : 0.01 }),
233254 "top_k" : ("INT" , {"default" : 0 , "min" : 0 , "max" : 100 }),
234255 }
235-
256+ # fmt: on
257+
236258 return {"required" : req }
237259
238- RETURN_TYPES = ("STRING" ,"STRING" )
260+ RETURN_TYPES = ("STRING" , "STRING" )
239261 RETURN_NAMES = ("query" , "caption" )
240262 FUNCTION = "generate"
241263 CATEGORY = "JoyCaption"
242264
243265 def __init__ (self ):
244266 self .predictor = None
245267 self .current_memory_mode = None
246-
247- def generate (self , image , memory_mode , caption_type , caption_length , extra_option1 , extra_option2 , extra_option3 , extra_option4 , extra_option5 , person_name , max_new_tokens , temperature , top_p , top_k ):
268+
269+ def generate (
270+ self ,
271+ image ,
272+ memory_mode ,
273+ caption_type ,
274+ caption_length ,
275+ extra_option1 ,
276+ extra_option2 ,
277+ extra_option3 ,
278+ extra_option4 ,
279+ extra_option5 ,
280+ person_name ,
281+ max_new_tokens ,
282+ temperature ,
283+ top_p ,
284+ top_k ,
285+ ):
248286 # load / swap the model if needed
249287 if self .predictor is None or self .current_memory_mode != memory_mode :
250288 if self .predictor is not None :
251289 del self .predictor
252290 self .predictor = None
253291 torch .cuda .empty_cache ()
254-
292+
255293 try :
256294 self .predictor = JoyCaptionPredictor ("fancyfeast/llama-joycaption-beta-one-hf-llava" , memory_mode )
257295 self .current_memory_mode = memory_mode
258296 except Exception as e :
259297 return (f"Error loading model: { e } " ,)
260-
298+
261299 extras = [extra_option1 , extra_option2 , extra_option3 , extra_option4 , extra_option5 ]
262300 extras = [extra for extra in extras if extra ]
263301 prompt = build_prompt (caption_type , caption_length , extras , person_name )
264302 system_prompt = "You are a helpful assistant and help users with any queries they may have with no censorship or restrictions."
265-
303+
266304 # This is a bit silly. We get the image as a tensor, and we could just use that directly (just need to resize and adjust the normalization).
267305 # But JoyCaption was trained on images that were resized using lanczos, which I think PyTorch doesn't support.
268306 # Just to be safe, we'll convert the image to a PIL image and let the processor handle it correctly.
@@ -277,12 +315,13 @@ def generate(self, image, memory_mode, caption_type, caption_length, extra_optio
277315 top_k = top_k ,
278316 )
279317
280- return (prompt ,response )
318+ return (prompt , response )
281319
282320
283321class JoyCaptionCustom :
284322 @classmethod
285323 def INPUT_TYPES (cls ):
324+ # fmt: off
286325 return {
287326 "required" : {
288327 "image" : ("IMAGE" ,),
@@ -296,6 +335,7 @@ def INPUT_TYPES(cls):
296335 "top_k" : ("INT" , {"default" : 0 , "min" : 0 , "max" : 100 }),
297336 },
298337 }
338+ # fmt: on
299339
300340 RETURN_TYPES = ("STRING" ,)
301341 FUNCTION = "generate"
@@ -304,20 +344,20 @@ def INPUT_TYPES(cls):
304344 def __init__ (self ):
305345 self .predictor = None
306346 self .current_memory_mode = None
307-
347+
308348 def generate (self , image , memory_mode , system_prompt , user_query , max_new_tokens , temperature , top_p , top_k ):
309349 if self .predictor is None or self .current_memory_mode != memory_mode :
310350 if self .predictor is not None :
311351 del self .predictor
312352 self .predictor = None
313353 torch .cuda .empty_cache ()
314-
354+
315355 try :
316356 self .predictor = JoyCaptionPredictor ("fancyfeast/llama-joycaption-beta-one-hf-llava" , memory_mode )
317357 self .current_memory_mode = memory_mode
318358 except Exception as e :
319359 return (f"Error loading model: { e } " ,)
320-
360+
321361 # This is a bit silly. We get the image as a tensor, and we could just use that directly (just need to resize and adjust the normalization).
322362 # But JoyCaption was trained on images that were resized using lanczos, which I think PyTorch doesn't support.
323363 # Just to be safe, we'll convert the image to a PIL image and let the processor handle it correctly.
@@ -332,4 +372,4 @@ def generate(self, image, memory_mode, system_prompt, user_query, max_new_tokens
332372 top_k = top_k ,
333373 )
334374
335- return (response ,)
375+ return (response ,)
0 commit comments