@@ -154,52 +154,85 @@ def __init__(self, model: str, memory_mode: str):
             repo_id=model, local_dir=str(checkpoint_path), force_download=False, local_files_only=False
         )
 
+        self.checkpoint_path = str(checkpoint_path)
+        self.memory_mode = memory_mode
+
         self.inference_device = model_management.get_torch_device()
         self.offload_device = model_management.unet_offload_device()
 
         self.processor = AutoProcessor.from_pretrained(str(checkpoint_path))
 
-        if memory_mode == "Default":
-            self.model = LlavaForConditionalGeneration.from_pretrained(str(checkpoint_path), torch_dtype="bfloat16")
+        self.model = None
+        self.model_size_bytes = None
+        self.is_kbit = self.memory_mode != "Default"
+
+    def _load_model(self):
+        # In normal mode:
+        # We load the model, free memory on the offload device, and then move it to the offload device.
+        # In quantized modes:
+        # The model must be loaded directly to the inference device.
+        # This function is only called during inference.
+        # After inference, if we need to offload, we just unload the model entirely.
+        # It'll be rebuilt during the next inference.
+        # We free memory on the inference device if we know how big the model is from a previous load.
+        if self.memory_mode == "Default":
+            self.model = LlavaForConditionalGeneration.from_pretrained(self.checkpoint_path, torch_dtype="bfloat16")
+            self.model_size_bytes = model_management.module_size(self.model)
+            model_management.free_memory(self.model_size_bytes, self.offload_device)
+            self.model.to(self.offload_device)
         else:
             from transformers import BitsAndBytesConfig
 
+            if self.model_size_bytes is not None:
+                model_management.free_memory(self.model_size_bytes, self.inference_device)
+
             qnt_config = BitsAndBytesConfig(
-                **MEMORY_EFFICIENT_CONFIGS[memory_mode],
+                **MEMORY_EFFICIENT_CONFIGS[self.memory_mode],
                 llm_int8_skip_modules=[
                     "vision_tower",
                     "multi_modal_projector",
                 ],  # Transformer's Siglip implementation has bugs when quantized, so skip those.
             )
+
             self.model = LlavaForConditionalGeneration.from_pretrained(
-                str(checkpoint_path), torch_dtype="auto", quantization_config=qnt_config
+                self.checkpoint_path,
+                torch_dtype="auto",
+                quantization_config=qnt_config,
+                device_map=_cuda_device_map(self.inference_device),
             )
-        print(f"Loaded model {model} with memory mode {memory_mode}")
-        # print(self.model)
+            self.model_size_bytes = model_management.module_size(self.model)
+
         self.model.eval()
-        self.model.to(self.offload_device)
 
-    def _get_inference_device(self) -> torch.device:
-        model_device = getattr(self.model, "device", None)
-        if isinstance(model_device, torch.device):
-            return model_device
-        if isinstance(model_device, str):
-            return torch.device(model_device)
-        return self.inference_device
+        print(f"Loaded model (mode={self.memory_mode}, kbit={self.is_kbit})")
 
     def prepare_for_inference(self):
-        model_management.free_memory(model_management.module_size(self.model), self.inference_device)
+        if self.model is None:
+            self._load_model()
+
+        if self.is_kbit:
+            return
+
+        model_management.free_memory(self.model_size_bytes, self.inference_device)
         self.model.to(self.inference_device)
 
     def cleanup_after_inference(self, keep_loaded: bool):
         if keep_loaded:
             return
+        if self.model is None:
+            return
+
+        if self.is_kbit:
+            self.unload()
+            return
+
         self.model.to(self.offload_device)
         model_management.soft_empty_cache()
 
     def unload(self):
-        if hasattr(self, "model"):
+        if self.model is not None:
             del self.model
+            self.model = None
         model_management.soft_empty_cache()
 
     @torch.inference_mode()
@@ -213,7 +246,9 @@ def generate(
         top_p: float,
         top_k: int,
     ) -> str:
+        # Load the model if it isn't already loaded and move it to the inference device if needed.
         self.prepare_for_inference()
+
         convo = [
             {
                 "role": "system",
@@ -303,7 +338,6 @@ def INPUT_TYPES(cls):
 
     def __init__(self):
         self.predictor = None
-        self.current_memory_mode = None
 
     def generate(
         self,
@@ -323,16 +357,18 @@ def generate(
         top_k: int,
         keep_loaded: bool,
     ):
+        if image.shape[0] != 1:
+            return ("", "Error: batch size greater than 1 is not supported.")
+
         # load / swap the model if needed
-        if self.predictor is None or self.current_memory_mode != memory_mode:
+        if self.predictor is None or self.predictor.memory_mode != memory_mode:
             if self.predictor is not None:
                 self.predictor.unload()
                 del self.predictor
                 self.predictor = None
 
             try:
                 self.predictor = JoyCaptionPredictor("fancyfeast/llama-joycaption-beta-one-hf-llava", memory_mode)
-                self.current_memory_mode = memory_mode
             except Exception as e:
                 return ("", f"Error loading model: {e}")
 
@@ -387,7 +423,6 @@ def INPUT_TYPES(cls):
 
     def __init__(self):
         self.predictor = None
-        self.current_memory_mode = None
 
     def generate(
         self,
@@ -401,15 +436,17 @@ def generate(
         top_k: int,
         keep_loaded: bool,
     ):
-        if self.predictor is None or self.current_memory_mode != memory_mode:
+        if image.shape[0] != 1:
+            return ("Error: batch size greater than 1 is not supported.",)
+
+        if self.predictor is None or self.predictor.memory_mode != memory_mode:
             if self.predictor is not None:
                 self.predictor.unload()
                 del self.predictor
                 self.predictor = None
 
             try:
                 self.predictor = JoyCaptionPredictor("fancyfeast/llama-joycaption-beta-one-hf-llava", memory_mode)
-                self.current_memory_mode = memory_mode
             except Exception as e:
                 return (f"Error loading model: {e}",)
 
@@ -431,3 +468,9 @@ def generate(
         self.predictor.cleanup_after_inference(keep_loaded=keep_loaded)
 
         return (response,)
+
+
+def _cuda_device_map(dev: torch.device):
+    if dev.type == "cuda":
+        return {"": (dev.index or 0)}
+    return {"": str(dev)}
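
Note: below is a standalone sketch, not part of the commit, showing what the new _cuda_device_map helper returns for a few assumed inputs. An empty-string key in a Hugging Face device_map assigns the entire module tree to one device, which is why the quantized from_pretrained call above loads the whole model directly onto the inference GPU.

import torch

def _cuda_device_map(dev: torch.device):
    # Empty-string key = map the entire model to this single device.
    if dev.type == "cuda":
        return {"": (dev.index or 0)}
    return {"": str(dev)}

# Assumed example calls:
print(_cuda_device_map(torch.device("cuda:1")))  # {'': 1}
print(_cuda_device_map(torch.device("cuda")))    # {'': 0}  (no index defaults to 0)
print(_cuda_device_map(torch.device("cpu")))     # {'': 'cpu'}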