@@ -546,21 +546,22 @@ def should_reload_model(self, force_patch_weights=False):
546546 return False
547547
548548 def model_unload (self , memory_to_free = None , unpatch_weights = True ):
549+ model_loaded_size = self .model .loaded_size ()
550+ if memory_to_free is None :
551+ # free the full model
552+ memory_to_free = model_loaded_size
553+
549554 logging .debug (f"model_unload: { self .model .model .__class__ .__name__ } " )
550555 logging .debug (f"memory_to_free: { memory_to_free / (1024 * 1024 * 1024 )} GB" )
551556 logging .debug (f"unpatch_weights: { unpatch_weights } " )
552- logging .debug (f"loaded_size: { self . model . loaded_size () / (1024 * 1024 * 1024 )} GB" )
557+ logging .debug (f"loaded_size: { model_loaded_size / (1024 * 1024 * 1024 )} GB" )
553558 logging .debug (f"offload_device: { self .model .offload_device } " )
554559
555- if memory_to_free is None :
556- # free the full model
557- memory_to_free = self .model .loaded_size ()
558-
559560 available_memory = get_free_memory (self .model .offload_device )
560561 logging .debug (f"before unload, available_memory of offload device { self .model .offload_device } : { available_memory / (1024 * 1024 * 1024 )} GB" )
561562
562563 mmap_mem_threshold = get_mmap_mem_threshold_gb () * 1024 * 1024 * 1024 # this is reserved memory for other system usage
563- if memory_to_free > available_memory - mmap_mem_threshold or memory_to_free < self . model . loaded_size () :
564+ if min ( memory_to_free , model_loaded_size ) > available_memory - mmap_mem_threshold or memory_to_free < model_loaded_size :
564565 partially_unload = True
565566 else :
566567 partially_unload = False
@@ -571,6 +572,8 @@ def model_unload(self, memory_to_free=None, unpatch_weights=True):
571572 logging .debug (f"partially_unload freed vram: { freed / (1024 * 1024 * 1024 )} GB" )
572573 if freed < memory_to_free :
573574 logging .warning (f"Partially unload not enough memory, freed { freed / (1024 * 1024 * 1024 )} GB, memory_to_free { memory_to_free / (1024 * 1024 * 1024 )} GB" )
575+ if freed == model_loaded_size :
576+ partially_unload = False
574577 else :
575578 logging .debug ("Do full unload" )
576579 self .model .detach (unpatch_weights )
0 commit comments