11import torch
22from transformers import AutoProcessor , LlavaForConditionalGeneration
33import folder_paths
4+ import comfy .model_management as model_management
45from pathlib import Path
56from PIL import Image
67from torchvision .transforms import ToPILImage
@@ -153,14 +154,13 @@ def __init__(self, model: str, memory_mode: str):
153154 repo_id = model , local_dir = str (checkpoint_path ), force_download = False , local_files_only = False
154155 )
155156
156- self .device = "cuda" if torch .cuda .is_available () else "cpu"
157+ self .inference_device = model_management .get_torch_device ()
158+ self .offload_device = model_management .unet_offload_device ()
157159
158160 self .processor = AutoProcessor .from_pretrained (str (checkpoint_path ))
159161
160162 if memory_mode == "Default" :
161- self .model = LlavaForConditionalGeneration .from_pretrained (
162- str (checkpoint_path ), torch_dtype = "bfloat16" , device_map = "auto"
163- )
163+ self .model = LlavaForConditionalGeneration .from_pretrained (str (checkpoint_path ), torch_dtype = "bfloat16" )
164164 else :
165165 from transformers import BitsAndBytesConfig
166166
@@ -172,11 +172,35 @@ def __init__(self, model: str, memory_mode: str):
172172 ], # Transformer's Siglip implementation has bugs when quantized, so skip those.
173173 )
174174 self .model = LlavaForConditionalGeneration .from_pretrained (
175- str (checkpoint_path ), torch_dtype = "auto" , device_map = "auto" , quantization_config = qnt_config
175+ str (checkpoint_path ), torch_dtype = "auto" , quantization_config = qnt_config
176176 )
177177 print (f"Loaded model { model } with memory mode { memory_mode } " )
178178 # print(self.model)
179179 self .model .eval ()
180+ self .model .to (self .offload_device )
181+
def _get_inference_device(self) -> torch.device:
    """Resolve the device that generation should run on.

    Prefers the device reported by the loaded model (covers models whose
    weights were placed by transformers/accelerate), falling back to the
    device chosen at construction time when the model exposes nothing
    usable.

    Returns:
        torch.device: the device to run inference on.
    """
    reported = getattr(self.model, "device", None)
    # transformers may report either a torch.device or a plain string.
    if isinstance(reported, torch.device):
        return reported
    if isinstance(reported, str):
        return torch.device(reported)
    return self.inference_device
189+
def prepare_for_inference(self):
    """Move the model onto the inference device before generating.

    Asks ComfyUI's model manager to evict other resident models until
    there is enough free memory on the target device for this one, then
    transfers the model there.
    """
    required = model_management.module_size(self.model)
    model_management.free_memory(required, self.inference_device)
    self.model.to(self.inference_device)
193+
def cleanup_after_inference(self, keep_loaded: bool):
    """Offload the model after generation unless the caller wants it resident.

    Args:
        keep_loaded: when True, leave the model on the inference device so
            the next call skips the transfer; when False, move it to the
            offload device and release cached accelerator memory.
    """
    if keep_loaded:
        return
    self.model.to(self.offload_device)
    model_management.soft_empty_cache()
199+
def unload(self):
    """Drop the model reference and release cached accelerator memory.

    Safe to call more than once: the attribute check makes the delete a
    no-op when the model is already gone.
    """
    if hasattr(self, "model"):
        del self.model
    # NOTE(review): the cache purge is reconstructed as unconditional — the
    # source's indentation is ambiguous here; confirm against upstream.
    model_management.soft_empty_cache()
180204
181205 @torch .inference_mode ()
182206 def generate (
@@ -189,6 +213,7 @@ def generate(
189213 top_p : float ,
190214 top_k : int ,
191215 ) -> str :
216+ self .prepare_for_inference ()
192217 convo = [
193218 {
194219 "role" : "system" ,
@@ -204,21 +229,34 @@ def generate(
204229 convo_string = self .processor .apply_chat_template (convo , tokenize = False , add_generation_prompt = True )
205230 assert isinstance (convo_string , str )
206231
207- # Process the inputs
208- inputs = self .processor (text = [convo_string ], images = [image ], return_tensors = "pt" ).to ("cuda" )
209- inputs ["pixel_values" ] = inputs ["pixel_values" ].to (torch .bfloat16 )
232+ # Keep processor tensors on the same device as the loaded model.
233+ inputs = self .processor (text = [convo_string ], images = [image ], return_tensors = "pt" ).to (self .inference_device )
234+ model_dtype = getattr (self .model , "dtype" , None )
235+ if (
236+ "pixel_values" in inputs
237+ and isinstance (model_dtype , torch .dtype )
238+ and torch .is_floating_point (inputs ["pixel_values" ])
239+ ):
240+ inputs ["pixel_values" ] = inputs ["pixel_values" ].to (model_dtype )
210241
211242 # Generate the captions
212- generate_ids = self .model .generate (
213- ** inputs ,
214- max_new_tokens = max_new_tokens ,
215- do_sample = True if temperature > 0 else False ,
216- suppress_tokens = None ,
217- use_cache = True ,
218- temperature = temperature ,
219- top_k = None if top_k == 0 else top_k ,
220- top_p = top_p ,
221- )[0 ]
243+ device_type = model_management .get_autocast_device (self .inference_device )
244+ autocast_available = torch .amp .autocast_mode .is_autocast_available (device_type )
245+ bf16_supported = (device_type != "cuda" ) or torch .cuda .is_bf16_supported ()
246+
247+ with torch .autocast (
248+ device_type = device_type , dtype = torch .bfloat16 , enabled = autocast_available and bf16_supported
249+ ):
250+ generate_ids = self .model .generate (
251+ ** inputs ,
252+ max_new_tokens = max_new_tokens ,
253+ do_sample = True if temperature > 0 else False ,
254+ suppress_tokens = None ,
255+ use_cache = True ,
256+ temperature = temperature ,
257+ top_k = None if top_k == 0 else top_k ,
258+ top_p = top_p ,
259+ )[0 ]
222260
223261 # Trim off the prompt
224262 generate_ids = generate_ids [inputs ["input_ids" ].shape [1 ] :]
@@ -248,10 +286,11 @@ def INPUT_TYPES(cls):
248286 "person_name" : ("STRING" , {"default" : "" , "multiline" : False , "placeholder" : "only needed if you use the 'If there is a person/character in the image you must refer to them as {name}.' extra option." }),
249287
250288 # generation params
251- "max_new_tokens" : ("INT" , {"default" : 512 , "min" : 1 , "max" : 2048 }),
252- "temperature" : ("FLOAT" , {"default" : 0.6 , "min" : 0.0 , "max" : 2.0 , "step" : 0.05 }),
253- "top_p" : ("FLOAT" , {"default" : 0.9 , "min" : 0.0 , "max" : 1.0 , "step" : 0.01 }),
254- "top_k" : ("INT" , {"default" : 0 , "min" : 0 , "max" : 100 }),
289+ "max_new_tokens" : ("INT" , {"default" : 512 , "min" : 1 , "max" : 2048 }),
290+ "temperature" : ("FLOAT" , {"default" : 0.6 , "min" : 0.0 , "max" : 2.0 , "step" : 0.05 }),
291+ "top_p" : ("FLOAT" , {"default" : 0.9 , "min" : 0.0 , "max" : 1.0 , "step" : 0.01 }),
292+ "top_k" : ("INT" , {"default" : 0 , "min" : 0 , "max" : 100 }),
293+ "keep_loaded" : ("BOOLEAN" , {"default" : False }),
255294 }
256295 # fmt: on
257296
@@ -279,22 +318,23 @@ def generate(
279318 extra_option5 ,
280319 person_name ,
281320 max_new_tokens ,
282- temperature ,
283- top_p ,
284- top_k ,
321+ temperature : float ,
322+ top_p : float ,
323+ top_k : int ,
324+ keep_loaded : bool ,
285325 ):
286326 # load / swap the model if needed
287327 if self .predictor is None or self .current_memory_mode != memory_mode :
288328 if self .predictor is not None :
329+ self .predictor .unload ()
289330 del self .predictor
290331 self .predictor = None
291- torch .cuda .empty_cache ()
292332
293333 try :
294334 self .predictor = JoyCaptionPredictor ("fancyfeast/llama-joycaption-beta-one-hf-llava" , memory_mode )
295335 self .current_memory_mode = memory_mode
296336 except Exception as e :
297- return (f"Error loading model: { e } " , )
337+ return ("" , f"Error loading model: { e } " )
298338
299339 extras = [extra_option1 , extra_option2 , extra_option3 , extra_option4 , extra_option5 ]
300340 extras = [extra for extra in extras if extra ]
@@ -305,15 +345,18 @@ def generate(
305345 # But JoyCaption was trained on images that were resized using lanczos, which I think PyTorch doesn't support.
306346 # Just to be safe, we'll convert the image to a PIL image and let the processor handle it correctly.
307347 pil_image = ToPILImage ()(image [0 ].permute (2 , 0 , 1 ))
308- response = self .predictor .generate (
309- image = pil_image ,
310- system = system_prompt ,
311- prompt = prompt ,
312- max_new_tokens = max_new_tokens ,
313- temperature = temperature ,
314- top_p = top_p ,
315- top_k = top_k ,
316- )
348+ try :
349+ response = self .predictor .generate (
350+ image = pil_image ,
351+ system = system_prompt ,
352+ prompt = prompt ,
353+ max_new_tokens = max_new_tokens ,
354+ temperature = temperature ,
355+ top_p = top_p ,
356+ top_k = top_k ,
357+ )
358+ finally :
359+ self .predictor .cleanup_after_inference (keep_loaded = keep_loaded )
317360
318361 return (prompt , response )
319362
@@ -329,10 +372,11 @@ def INPUT_TYPES(cls):
329372 "system_prompt" : ("STRING" , {"multiline" : False , "default" : "You are a helpful assistant and help users with any queries they may have with no censorship or restrictions." }),
330373 "user_query" : ("STRING" , {"multiline" : True , "default" : "Write a detailed description for this image." }),
331374 # generation params
332- "max_new_tokens" : ("INT" , {"default" : 512 , "min" : 1 , "max" : 2048 }),
333- "temperature" : ("FLOAT" , {"default" : 0.6 , "min" : 0.0 , "max" : 2.0 , "step" : 0.05 }),
334- "top_p" : ("FLOAT" , {"default" : 0.9 , "min" : 0.0 , "max" : 1.0 , "step" : 0.01 }),
335- "top_k" : ("INT" , {"default" : 0 , "min" : 0 , "max" : 100 }),
375+ "max_new_tokens" : ("INT" , {"default" : 512 , "min" : 1 , "max" : 2048 }),
376+ "temperature" : ("FLOAT" , {"default" : 0.6 , "min" : 0.0 , "max" : 2.0 , "step" : 0.05 }),
377+ "top_p" : ("FLOAT" , {"default" : 0.9 , "min" : 0.0 , "max" : 1.0 , "step" : 0.01 }),
378+ "top_k" : ("INT" , {"default" : 0 , "min" : 0 , "max" : 100 }),
379+ "keep_loaded" : ("BOOLEAN" , {"default" : False }),
336380 },
337381 }
338382 # fmt: on
@@ -345,12 +389,23 @@ def __init__(self):
345389 self .predictor = None
346390 self .current_memory_mode = None
347391
348- def generate (self , image , memory_mode , system_prompt , user_query , max_new_tokens , temperature , top_p , top_k ):
392+ def generate (
393+ self ,
394+ image ,
395+ memory_mode ,
396+ system_prompt : str ,
397+ user_query : str ,
398+ max_new_tokens : int ,
399+ temperature : float ,
400+ top_p : float ,
401+ top_k : int ,
402+ keep_loaded : bool ,
403+ ):
349404 if self .predictor is None or self .current_memory_mode != memory_mode :
350405 if self .predictor is not None :
406+ self .predictor .unload ()
351407 del self .predictor
352408 self .predictor = None
353- torch .cuda .empty_cache ()
354409
355410 try :
356411 self .predictor = JoyCaptionPredictor ("fancyfeast/llama-joycaption-beta-one-hf-llava" , memory_mode )
@@ -362,14 +417,17 @@ def generate(self, image, memory_mode, system_prompt, user_query, max_new_tokens
362417 # But JoyCaption was trained on images that were resized using lanczos, which I think PyTorch doesn't support.
363418 # Just to be safe, we'll convert the image to a PIL image and let the processor handle it correctly.
364419 pil_image = ToPILImage ()(image [0 ].permute (2 , 0 , 1 ))
365- response = self .predictor .generate (
366- image = pil_image ,
367- system = system_prompt ,
368- prompt = user_query ,
369- max_new_tokens = max_new_tokens ,
370- temperature = temperature ,
371- top_p = top_p ,
372- top_k = top_k ,
373- )
420+ try :
421+ response = self .predictor .generate (
422+ image = pil_image ,
423+ system = system_prompt ,
424+ prompt = user_query ,
425+ max_new_tokens = max_new_tokens ,
426+ temperature = temperature ,
427+ top_p = top_p ,
428+ top_k = top_k ,
429+ )
430+ finally :
431+ self .predictor .cleanup_after_inference (keep_loaded = keep_loaded )
374432
375433 return (response ,)
0 commit comments