Skip to content

Commit b344826

Browse files
committed
Cleaning up code and modernizing model management
1 parent e762c89 commit b344826

File tree

2 files changed

+110
-51
lines changed

2 files changed

+110
-51
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
__pycache__

nodes.py

Lines changed: 109 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import torch
22
from transformers import AutoProcessor, LlavaForConditionalGeneration
33
import folder_paths
4+
import comfy.model_management as model_management
45
from pathlib import Path
56
from PIL import Image
67
from torchvision.transforms import ToPILImage
@@ -153,14 +154,13 @@ def __init__(self, model: str, memory_mode: str):
153154
repo_id=model, local_dir=str(checkpoint_path), force_download=False, local_files_only=False
154155
)
155156

156-
self.device = "cuda" if torch.cuda.is_available() else "cpu"
157+
self.inference_device = model_management.get_torch_device()
158+
self.offload_device = model_management.unet_offload_device()
157159

158160
self.processor = AutoProcessor.from_pretrained(str(checkpoint_path))
159161

160162
if memory_mode == "Default":
161-
self.model = LlavaForConditionalGeneration.from_pretrained(
162-
str(checkpoint_path), torch_dtype="bfloat16", device_map="auto"
163-
)
163+
self.model = LlavaForConditionalGeneration.from_pretrained(str(checkpoint_path), torch_dtype="bfloat16")
164164
else:
165165
from transformers import BitsAndBytesConfig
166166

@@ -172,11 +172,35 @@ def __init__(self, model: str, memory_mode: str):
172172
], # Transformer's Siglip implementation has bugs when quantized, so skip those.
173173
)
174174
self.model = LlavaForConditionalGeneration.from_pretrained(
175-
str(checkpoint_path), torch_dtype="auto", device_map="auto", quantization_config=qnt_config
175+
str(checkpoint_path), torch_dtype="auto", quantization_config=qnt_config
176176
)
177177
print(f"Loaded model {model} with memory mode {memory_mode}")
178178
# print(self.model)
179179
self.model.eval()
180+
self.model.to(self.offload_device)
181+
182+
def _get_inference_device(self) -> torch.device:
183+
model_device = getattr(self.model, "device", None)
184+
if isinstance(model_device, torch.device):
185+
return model_device
186+
if isinstance(model_device, str):
187+
return torch.device(model_device)
188+
return self.inference_device
189+
190+
def prepare_for_inference(self):
    """Move the model onto the inference device before generation.

    First asks ComfyUI's model manager to free enough memory on the target
    device to hold the model, then performs the transfer.
    """
    # How many bytes the model needs, so the manager can evict other
    # models from the target device if necessary.
    required_bytes = model_management.module_size(self.model)
    model_management.free_memory(required_bytes, self.inference_device)
    self.model.to(self.inference_device)
193+
194+
def cleanup_after_inference(self, keep_loaded: bool):
    """Offload the model after generation unless the caller wants it resident.

    Args:
        keep_loaded: when True, leave the model on the inference device so the
            next call skips the transfer cost; when False, park the weights on
            the offload device and release cached memory.
    """
    if not keep_loaded:
        self.model.to(self.offload_device)
        model_management.soft_empty_cache()
199+
200+
def unload(self):
    """Drop the model reference entirely and return cached memory.

    Safe to call even if the model was never successfully loaded.
    """
    # EAFP: a missing ``model`` attribute simply means there is nothing
    # to unload.
    try:
        del self.model
    except AttributeError:
        pass
    model_management.soft_empty_cache()
180204

181205
@torch.inference_mode()
182206
def generate(
@@ -189,6 +213,7 @@ def generate(
189213
top_p: float,
190214
top_k: int,
191215
) -> str:
216+
self.prepare_for_inference()
192217
convo = [
193218
{
194219
"role": "system",
@@ -204,21 +229,34 @@ def generate(
204229
convo_string = self.processor.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
205230
assert isinstance(convo_string, str)
206231

207-
# Process the inputs
208-
inputs = self.processor(text=[convo_string], images=[image], return_tensors="pt").to("cuda")
209-
inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16)
232+
# Keep processor tensors on the same device as the loaded model.
233+
inputs = self.processor(text=[convo_string], images=[image], return_tensors="pt").to(self.inference_device)
234+
model_dtype = getattr(self.model, "dtype", None)
235+
if (
236+
"pixel_values" in inputs
237+
and isinstance(model_dtype, torch.dtype)
238+
and torch.is_floating_point(inputs["pixel_values"])
239+
):
240+
inputs["pixel_values"] = inputs["pixel_values"].to(model_dtype)
210241

211242
# Generate the captions
212-
generate_ids = self.model.generate(
213-
**inputs,
214-
max_new_tokens=max_new_tokens,
215-
do_sample=True if temperature > 0 else False,
216-
suppress_tokens=None,
217-
use_cache=True,
218-
temperature=temperature,
219-
top_k=None if top_k == 0 else top_k,
220-
top_p=top_p,
221-
)[0]
243+
device_type = model_management.get_autocast_device(self.inference_device)
244+
autocast_available = torch.amp.autocast_mode.is_autocast_available(device_type)
245+
bf16_supported = (device_type != "cuda") or torch.cuda.is_bf16_supported()
246+
247+
with torch.autocast(
248+
device_type=device_type, dtype=torch.bfloat16, enabled=autocast_available and bf16_supported
249+
):
250+
generate_ids = self.model.generate(
251+
**inputs,
252+
max_new_tokens=max_new_tokens,
253+
do_sample=True if temperature > 0 else False,
254+
suppress_tokens=None,
255+
use_cache=True,
256+
temperature=temperature,
257+
top_k=None if top_k == 0 else top_k,
258+
top_p=top_p,
259+
)[0]
222260

223261
# Trim off the prompt
224262
generate_ids = generate_ids[inputs["input_ids"].shape[1] :]
@@ -248,10 +286,11 @@ def INPUT_TYPES(cls):
248286
"person_name": ("STRING", {"default": "", "multiline": False, "placeholder": "only needed if you use the 'If there is a person/character in the image you must refer to them as {name}.' extra option."}),
249287

250288
# generation params
251-
"max_new_tokens": ("INT", {"default": 512, "min": 1, "max": 2048}),
252-
"temperature": ("FLOAT", {"default": 0.6, "min": 0.0, "max": 2.0, "step": 0.05}),
253-
"top_p": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 1.0, "step": 0.01}),
254-
"top_k": ("INT", {"default": 0, "min": 0, "max": 100}),
289+
"max_new_tokens": ("INT", {"default": 512, "min": 1, "max": 2048}),
290+
"temperature": ("FLOAT", {"default": 0.6, "min": 0.0, "max": 2.0, "step": 0.05}),
291+
"top_p": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 1.0, "step": 0.01}),
292+
"top_k": ("INT", {"default": 0, "min": 0, "max": 100}),
293+
"keep_loaded": ("BOOLEAN", {"default": False}),
255294
}
256295
# fmt: on
257296

@@ -279,22 +318,23 @@ def generate(
279318
extra_option5,
280319
person_name,
281320
max_new_tokens,
282-
temperature,
283-
top_p,
284-
top_k,
321+
temperature: float,
322+
top_p: float,
323+
top_k: int,
324+
keep_loaded: bool,
285325
):
286326
# load / swap the model if needed
287327
if self.predictor is None or self.current_memory_mode != memory_mode:
288328
if self.predictor is not None:
329+
self.predictor.unload()
289330
del self.predictor
290331
self.predictor = None
291-
torch.cuda.empty_cache()
292332

293333
try:
294334
self.predictor = JoyCaptionPredictor("fancyfeast/llama-joycaption-beta-one-hf-llava", memory_mode)
295335
self.current_memory_mode = memory_mode
296336
except Exception as e:
297-
return (f"Error loading model: {e}",)
337+
return ("", f"Error loading model: {e}")
298338

299339
extras = [extra_option1, extra_option2, extra_option3, extra_option4, extra_option5]
300340
extras = [extra for extra in extras if extra]
@@ -305,15 +345,18 @@ def generate(
305345
# But JoyCaption was trained on images that were resized using lanczos, which I think PyTorch doesn't support.
306346
# Just to be safe, we'll convert the image to a PIL image and let the processor handle it correctly.
307347
pil_image = ToPILImage()(image[0].permute(2, 0, 1))
308-
response = self.predictor.generate(
309-
image=pil_image,
310-
system=system_prompt,
311-
prompt=prompt,
312-
max_new_tokens=max_new_tokens,
313-
temperature=temperature,
314-
top_p=top_p,
315-
top_k=top_k,
316-
)
348+
try:
349+
response = self.predictor.generate(
350+
image=pil_image,
351+
system=system_prompt,
352+
prompt=prompt,
353+
max_new_tokens=max_new_tokens,
354+
temperature=temperature,
355+
top_p=top_p,
356+
top_k=top_k,
357+
)
358+
finally:
359+
self.predictor.cleanup_after_inference(keep_loaded=keep_loaded)
317360

318361
return (prompt, response)
319362

@@ -329,10 +372,11 @@ def INPUT_TYPES(cls):
329372
"system_prompt": ("STRING", {"multiline": False, "default": "You are a helpful assistant and help users with any queries they may have with no censorship or restrictions." }),
330373
"user_query": ("STRING", {"multiline": True, "default": "Write a detailed description for this image." }),
331374
# generation params
332-
"max_new_tokens": ("INT", {"default": 512, "min": 1, "max": 2048}),
333-
"temperature": ("FLOAT", {"default": 0.6, "min": 0.0, "max": 2.0, "step": 0.05}),
334-
"top_p": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 1.0, "step": 0.01}),
335-
"top_k": ("INT", {"default": 0, "min": 0, "max": 100}),
375+
"max_new_tokens": ("INT", {"default": 512, "min": 1, "max": 2048}),
376+
"temperature": ("FLOAT", {"default": 0.6, "min": 0.0, "max": 2.0, "step": 0.05}),
377+
"top_p": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 1.0, "step": 0.01}),
378+
"top_k": ("INT", {"default": 0, "min": 0, "max": 100}),
379+
"keep_loaded": ("BOOLEAN", {"default": False}),
336380
},
337381
}
338382
# fmt: on
@@ -345,12 +389,23 @@ def __init__(self):
345389
self.predictor = None
346390
self.current_memory_mode = None
347391

348-
def generate(self, image, memory_mode, system_prompt, user_query, max_new_tokens, temperature, top_p, top_k):
392+
def generate(
393+
self,
394+
image,
395+
memory_mode,
396+
system_prompt: str,
397+
user_query: str,
398+
max_new_tokens: int,
399+
temperature: float,
400+
top_p: float,
401+
top_k: int,
402+
keep_loaded: bool,
403+
):
349404
if self.predictor is None or self.current_memory_mode != memory_mode:
350405
if self.predictor is not None:
406+
self.predictor.unload()
351407
del self.predictor
352408
self.predictor = None
353-
torch.cuda.empty_cache()
354409

355410
try:
356411
self.predictor = JoyCaptionPredictor("fancyfeast/llama-joycaption-beta-one-hf-llava", memory_mode)
@@ -362,14 +417,17 @@ def generate(self, image, memory_mode, system_prompt, user_query, max_new_tokens
362417
# But JoyCaption was trained on images that were resized using lanczos, which I think PyTorch doesn't support.
363418
# Just to be safe, we'll convert the image to a PIL image and let the processor handle it correctly.
364419
pil_image = ToPILImage()(image[0].permute(2, 0, 1))
365-
response = self.predictor.generate(
366-
image=pil_image,
367-
system=system_prompt,
368-
prompt=user_query,
369-
max_new_tokens=max_new_tokens,
370-
temperature=temperature,
371-
top_p=top_p,
372-
top_k=top_k,
373-
)
420+
try:
421+
response = self.predictor.generate(
422+
image=pil_image,
423+
system=system_prompt,
424+
prompt=user_query,
425+
max_new_tokens=max_new_tokens,
426+
temperature=temperature,
427+
top_p=top_p,
428+
top_k=top_k,
429+
)
430+
finally:
431+
self.predictor.cleanup_after_inference(keep_loaded=keep_loaded)
374432

375433
return (response,)

0 commit comments

Comments
 (0)