Commit ccb6cde

Fix model offloading when in quantization modes
1 parent b344826 commit ccb6cde
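
Background: bitsandbytes-quantized (4-bit/8-bit) models cannot be moved between devices with .to(), so the previous pattern of loading the model and parking it on the offload device broke in the memory-efficient modes. With this change, loading is lazy: the bf16 "Default" path still offloads with .to() as before, while quantized ("k-bit") models are loaded directly onto the inference device via a device_map and, instead of being offloaded, are unloaded entirely and rebuilt on the next inference.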

File tree: 1 file changed

nodes.py (65 additions, 22 deletions)
@@ -154,52 +154,85 @@ def __init__(self, model: str, memory_mode: str):
             repo_id=model, local_dir=str(checkpoint_path), force_download=False, local_files_only=False
         )

+        self.checkpoint_path = str(checkpoint_path)
+        self.memory_mode = memory_mode
+
         self.inference_device = model_management.get_torch_device()
         self.offload_device = model_management.unet_offload_device()

         self.processor = AutoProcessor.from_pretrained(str(checkpoint_path))

-        if memory_mode == "Default":
-            self.model = LlavaForConditionalGeneration.from_pretrained(str(checkpoint_path), torch_dtype="bfloat16")
+        self.model = None
+        self.model_size_bytes = None
+        self.is_kbit = self.memory_mode != "Default"
+
+    def _load_model(self):
+        # In normal mode:
+        # We load the model, free memory on the offload device, and then move it to the offload device.
+        # In quantized modes:
+        # The model must be loaded directly to the inference device.
+        # This function is only called during inference.
+        # After inference, if we need to offload, we just unload the model entirely.
+        # It'll be rebuilt during the next inference.
+        # We free memory on the inference device if we know how big the model is from a previous load.
+        if self.memory_mode == "Default":
+            self.model = LlavaForConditionalGeneration.from_pretrained(self.checkpoint_path, torch_dtype="bfloat16")
+            self.model_size_bytes = model_management.module_size(self.model)
+            model_management.free_memory(self.model_size_bytes, self.offload_device)
+            self.model.to(self.offload_device)
         else:
             from transformers import BitsAndBytesConfig

+            if self.model_size_bytes is not None:
+                model_management.free_memory(self.model_size_bytes, self.inference_device)
+
             qnt_config = BitsAndBytesConfig(
-                **MEMORY_EFFICIENT_CONFIGS[memory_mode],
+                **MEMORY_EFFICIENT_CONFIGS[self.memory_mode],
                 llm_int8_skip_modules=[
                     "vision_tower",
                     "multi_modal_projector",
                 ],  # Transformers' Siglip implementation has bugs when quantized, so skip those.
             )
+
             self.model = LlavaForConditionalGeneration.from_pretrained(
-                str(checkpoint_path), torch_dtype="auto", quantization_config=qnt_config
+                self.checkpoint_path,
+                torch_dtype="auto",
+                quantization_config=qnt_config,
+                device_map=_cuda_device_map(self.inference_device),
             )
-        print(f"Loaded model {model} with memory mode {memory_mode}")
-        # print(self.model)
+            self.model_size_bytes = model_management.module_size(self.model)
+
         self.model.eval()
-        self.model.to(self.offload_device)

-    def _get_inference_device(self) -> torch.device:
-        model_device = getattr(self.model, "device", None)
-        if isinstance(model_device, torch.device):
-            return model_device
-        if isinstance(model_device, str):
-            return torch.device(model_device)
-        return self.inference_device
+        print(f"Loaded model (mode={self.memory_mode}, kbit={self.is_kbit})")

     def prepare_for_inference(self):
-        model_management.free_memory(model_management.module_size(self.model), self.inference_device)
+        if self.model is None:
+            self._load_model()
+
+        if self.is_kbit:
+            return
+
+        model_management.free_memory(self.model_size_bytes, self.inference_device)
         self.model.to(self.inference_device)

     def cleanup_after_inference(self, keep_loaded: bool):
         if keep_loaded:
             return
+        if self.model is None:
+            return
+
+        if self.is_kbit:
+            self.unload()
+            return
+
         self.model.to(self.offload_device)
         model_management.soft_empty_cache()

     def unload(self):
-        if hasattr(self, "model"):
+        if self.model is not None:
             del self.model
+            self.model = None
         model_management.soft_empty_cache()

     @torch.inference_mode()
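
For orientation, a minimal sketch of the lifecycle the hunk above implements (a hypothetical caller, not part of the commit; names follow the diff):

    # Default mode: load once, then shuttle the bf16 model with .to().
    # k-bit modes: build the quantized model on the GPU and tear it down
    # instead of offloading, since bitsandbytes models cannot be moved.
    predictor = JoyCaptionPredictor(model, memory_mode)  # no weights loaded yet

    predictor.prepare_for_inference()
    # Default: _load_model() on first use, then model.to(inference_device)
    # k-bit:   _load_model() materializes weights on the GPU via device_map

    # ... generation happens here ...

    predictor.cleanup_after_inference(keep_loaded=False)
    # Default: model.to(offload_device); weights stay resident
    # k-bit:   unload() drops the model entirely; the next run rebuilds it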
@@ -213,7 +246,9 @@ def generate(
         top_p: float,
         top_k: int,
     ) -> str:
+        # Load the model if it isn't already loaded and move it to the inference device if needed.
         self.prepare_for_inference()
+
         convo = [
             {
                 "role": "system",
@@ -303,7 +338,6 @@ def INPUT_TYPES(cls):

     def __init__(self):
         self.predictor = None
-        self.current_memory_mode = None

     def generate(
         self,
@@ -323,16 +357,18 @@ def generate(
         top_k: int,
         keep_loaded: bool,
     ):
+        if image.shape[0] != 1:
+            return ("", "Error: batch size greater than 1 is not supported.")
+
         # load / swap the model if needed
-        if self.predictor is None or self.current_memory_mode != memory_mode:
+        if self.predictor is None or self.predictor.memory_mode != memory_mode:
             if self.predictor is not None:
                 self.predictor.unload()
                 del self.predictor
                 self.predictor = None

             try:
                 self.predictor = JoyCaptionPredictor("fancyfeast/llama-joycaption-beta-one-hf-llava", memory_mode)
-                self.current_memory_mode = memory_mode
             except Exception as e:
                 return ("", f"Error loading model: {e}")

@@ -387,7 +423,6 @@ def INPUT_TYPES(cls):

     def __init__(self):
         self.predictor = None
-        self.current_memory_mode = None

     def generate(
         self,
@@ -401,15 +436,17 @@ def generate(
         top_k: int,
         keep_loaded: bool,
     ):
-        if self.predictor is None or self.current_memory_mode != memory_mode:
+        if image.shape[0] != 1:
+            return ("Error: batch size greater than 1 is not supported.",)
+
+        if self.predictor is None or self.predictor.memory_mode != memory_mode:
             if self.predictor is not None:
                 self.predictor.unload()
                 del self.predictor
                 self.predictor = None

             try:
                 self.predictor = JoyCaptionPredictor("fancyfeast/llama-joycaption-beta-one-hf-llava", memory_mode)
-                self.current_memory_mode = memory_mode
             except Exception as e:
                 return (f"Error loading model: {e}",)

@@ -431,3 +468,9 @@ def generate(
         self.predictor.cleanup_after_inference(keep_loaded=keep_loaded)

         return (response,)
+
+
+def _cuda_device_map(dev: torch.device):
+    if dev.type == "cuda":
+        return {"": (dev.index or 0)}
+    return {"": str(dev)}
