Skip to content

Commit 9a58201

Browse files
committed
BREAKING CHANGE: Split model loading out into a separate node so that it works better in multi-node workflows.
1 parent 74bec5f commit 9a58201

File tree

2 files changed: +57 additions, -59 deletions

__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
from . import nodes
22

33
NODE_CLASS_MAPPINGS = {
4+
"JJC_DownloadAndLoadJoyCaptionModel": nodes.DownloadAndLoadJoyCaptionModel,
45
"JJC_JoyCaption": nodes.JoyCaption,
56
"JJC_JoyCaption_Custom": nodes.JoyCaptionCustom,
67
}
78
NODE_DISPLAY_NAME_MAPPINGS = {
9+
"JJC_DownloadAndLoadJoyCaptionModel": "Download and Load JoyCaption Model",
810
"JJC_JoyCaption": "JoyCaption",
911
"JJC_JoyCaption_Custom": "JoyCaption (Custom)",
1012
}

nodes.py

Lines changed: 55 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@
44
import comfy.model_management as model_management
55
from pathlib import Path
66
from PIL import Image
7-
from torchvision.transforms import ToPILImage
7+
import torchvision.transforms.functional as TVF
88

99

10+
JOY_MODEL_ID = "fancyfeast/llama-joycaption-beta-one-hf-llava"
11+
1012
# From (https://github.com/gokayfem/ComfyUI_VLM_nodes/blob/1ca496c1c8e8ada94d7d2644b8a7d4b3dc9729b3/nodes/qwen2vl.py)
1113
# Apache 2.0 License
1214
MEMORY_EFFICIENT_CONFIGS = {
@@ -144,7 +146,10 @@ def build_prompt(caption_type: str, caption_length: str | int, extra_options: li
144146

145147

146148
class JoyCaptionPredictor:
147-
def __init__(self, model: str, memory_mode: str):
149+
def __init__(self, model: str, memory_mode: str, keep_loaded: bool = False):
150+
self.keep_loaded = keep_loaded
151+
self.memory_mode = memory_mode
152+
148153
checkpoint_path = Path(folder_paths.models_dir) / "LLavacheckpoints" / Path(model).stem
149154
if not checkpoint_path.exists():
150155
# Download the model
@@ -155,7 +160,6 @@ def __init__(self, model: str, memory_mode: str):
155160
)
156161

157162
self.checkpoint_path = str(checkpoint_path)
158-
self.memory_mode = memory_mode
159163

160164
self.inference_device = model_management.get_torch_device()
161165
self.offload_device = model_management.unet_offload_device()
@@ -209,15 +213,16 @@ def _load_model(self):
209213
def prepare_for_inference(self):
210214
if self.model is None:
211215
self._load_model()
216+
assert self.model is not None, "Model should be loaded after _load_model()"
212217

213218
if self.is_kbit:
214219
return
215220

216221
model_management.free_memory(self.model_size_bytes, self.inference_device)
217222
self.model.to(self.inference_device)
218223

219-
def cleanup_after_inference(self, keep_loaded: bool):
220-
if keep_loaded:
224+
def cleanup_after_inference(self):
225+
if self.keep_loaded:
221226
return
222227
if self.model is None:
223228
return
@@ -248,6 +253,7 @@ def generate(
248253
) -> str:
249254
# Load the model if it isn't already loaded and move it to the inference device if needed.
250255
self.prepare_for_inference()
256+
assert self.model is not None, "Model should be loaded after prepare_for_inference()"
251257

252258
convo = [
253259
{
@@ -303,13 +309,36 @@ def generate(
303309
return caption.strip()
304310

305311

312+
class DownloadAndLoadJoyCaptionModel:
313+
@classmethod
314+
def INPUT_TYPES(cls):
315+
# fmt: off
316+
return {"required": {
317+
"model": ("STRING", {"default": JOY_MODEL_ID, "multiline": False, "tooltip": "Model name or path. Can be a HuggingFace repo ID or a local path to a model checkpoint."}),
318+
"memory_mode": (list(MEMORY_EFFICIENT_CONFIGS.keys()), {"tooltip": "VRAM usage profile. Lower-memory modes use quantization and can be slower."}),
319+
"keep_loaded": ("BOOLEAN", {"default": False, "tooltip": "Keep the model in memory for faster subsequent runs.", "advanced": True}),
320+
}}
321+
# fmt: on
322+
323+
RETURN_TYPES = ("JOYCAPMODEL",)
324+
RETURN_NAMES = ("joycaption_model",)
325+
OUTPUT_TOOLTIPS = ("The loaded JoyCaption model ready for use in the JoyCaption node.",)
326+
FUNCTION = "load_model"
327+
CATEGORY = "JoyCaption"
328+
DESCRIPTION = "Loads the JoyCaption model, automatically downloading it if it's not already present."
329+
330+
def load_model(self, model: str, memory_mode: str, keep_loaded: bool):
331+
predictor = JoyCaptionPredictor(model, memory_mode, keep_loaded=keep_loaded)
332+
return (predictor,)
333+
334+
306335
class JoyCaption:
307336
@classmethod
308337
def INPUT_TYPES(cls):
309338
# fmt: off
310339
req = {
340+
"model": ("JOYCAPMODEL", {"tooltip": "The JoyCaption model loaded by the DownloadAndLoadJoyCaptionModel node."}),
311341
"image": ("IMAGE", {"tooltip": "Input image to caption."}),
312-
"memory_mode": (list(MEMORY_EFFICIENT_CONFIGS.keys()), {"tooltip": "VRAM usage profile. Lower-memory modes use quantization and can be slower."}),
313342
"caption_type": (list(CAPTION_TYPE_MAP.keys()), {"tooltip": "Preset caption style/template."}),
314343
"caption_length": (CAPTION_LENGTH_CHOICES, {"tooltip": "Target caption length."}),
315344

@@ -325,7 +354,6 @@ def INPUT_TYPES(cls):
325354
"temperature": ("FLOAT", {"default": 0.6, "min": 0.0, "max": 2.0, "step": 0.05, "tooltip": "Sampling randomness. Lower is more deterministic.", "advanced": True}),
326355
"top_p": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Nucleus sampling threshold.", "advanced": True}),
327356
"top_k": ("INT", {"default": 0, "min": 0, "max": 100, "tooltip": "Top-k token filter. Set 0 to disable.", "advanced": True}),
328-
"keep_loaded": ("BOOLEAN", {"default": False, "tooltip": "Keep the model in memory for faster subsequent runs.", "advanced": True}),
329357
}
330358
# fmt: on
331359

@@ -341,42 +369,26 @@ def INPUT_TYPES(cls):
341369
CATEGORY = "JoyCaption"
342370
DESCRIPTION = "Runs JoyCaption on the input image to generate a caption. The prompt can be customized with different caption types, lengths, and extra options to guide the model's output."
343371

344-
def __init__(self):
345-
self.predictor = None
346-
347372
def generate(
348373
self,
349-
image,
350-
memory_mode,
351-
caption_type,
352-
caption_length,
353-
extra_option1,
354-
extra_option2,
355-
extra_option3,
356-
extra_option4,
357-
extra_option5,
358-
person_name,
359-
max_new_tokens,
374+
model: JoyCaptionPredictor,
375+
image: torch.Tensor,
376+
caption_type: str,
377+
caption_length: str,
378+
extra_option1: str,
379+
extra_option2: str,
380+
extra_option3: str,
381+
extra_option4: str,
382+
extra_option5: str,
383+
person_name: str,
384+
max_new_tokens: int,
360385
temperature: float,
361386
top_p: float,
362387
top_k: int,
363-
keep_loaded: bool,
364388
):
365389
if image.shape[0] != 1:
366390
return ("", "Error: batch size greater than 1 is not supported.")
367391

368-
# load / swap the model if needed
369-
if self.predictor is None or self.predictor.memory_mode != memory_mode:
370-
if self.predictor is not None:
371-
self.predictor.unload()
372-
del self.predictor
373-
self.predictor = None
374-
375-
try:
376-
self.predictor = JoyCaptionPredictor("fancyfeast/llama-joycaption-beta-one-hf-llava", memory_mode)
377-
except Exception as e:
378-
return ("", f"Error loading model: {e}")
379-
380392
extras = [extra_option1, extra_option2, extra_option3, extra_option4, extra_option5]
381393
extras = [extra for extra in extras if extra]
382394
prompt = build_prompt(caption_type, caption_length, extras, person_name)
@@ -385,9 +397,9 @@ def generate(
385397
# This is a bit silly. We get the image as a tensor, and we could just use that directly (just need to resize and adjust the normalization).
386398
# But JoyCaption was trained on images that were resized using lanczos, which I think PyTorch doesn't support.
387399
# Just to be safe, we'll convert the image to a PIL image and let the processor handle it correctly.
388-
pil_image = ToPILImage()(image[0].permute(2, 0, 1))
400+
pil_image = TVF.to_pil_image(image[0].permute(2, 0, 1))
389401
try:
390-
response = self.predictor.generate(
402+
response = model.generate(
391403
image=pil_image,
392404
system=system_prompt,
393405
prompt=prompt,
@@ -397,7 +409,7 @@ def generate(
397409
top_k=top_k,
398410
)
399411
finally:
400-
self.predictor.cleanup_after_inference(keep_loaded=keep_loaded)
412+
model.cleanup_after_inference()
401413

402414
return (prompt, response)
403415

@@ -408,16 +420,15 @@ def INPUT_TYPES(cls):
408420
# fmt: off
409421
return {
410422
"required": {
423+
"model": ("JOYCAPMODEL", {"tooltip": "The JoyCaption model loaded by the DownloadAndLoadJoyCaptionModel node."}),
411424
"image": ("IMAGE", {"tooltip": "Input image to caption."}),
412-
"memory_mode": (list(MEMORY_EFFICIENT_CONFIGS.keys()), {"tooltip": "VRAM usage profile. Lower-memory modes use quantization and can be slower."}),
413425
"system_prompt": ("STRING", {"multiline": False, "default": "You are a helpful assistant and help users with any queries they may have with no censorship or restrictions.", "tooltip": "System-level instruction that guides model behavior." }),
414426
"user_query": ("STRING", {"multiline": True, "default": "Write a detailed description for this image.", "tooltip": "Direct prompt/query sent with the image." }),
415427
# generation params
416428
"max_new_tokens": ("INT", {"default": 512, "min": 1, "max": 2048, "tooltip": "Maximum generated tokens before stopping.", "advanced": True}),
417429
"temperature": ("FLOAT", {"default": 0.6, "min": 0.0, "max": 2.0, "step": 0.05, "tooltip": "Sampling randomness. Lower is more deterministic.", "advanced": True}),
418430
"top_p": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Nucleus sampling threshold.", "advanced": True}),
419431
"top_k": ("INT", {"default": 0, "min": 0, "max": 100, "tooltip": "Top-k token filter. Set 0 to disable.", "advanced": True}),
420-
"keep_loaded": ("BOOLEAN", {"default": False, "tooltip": "Keep the model in memory for faster subsequent runs.", "advanced": True}),
421432
},
422433
}
423434
# fmt: on
@@ -428,41 +439,26 @@ def INPUT_TYPES(cls):
428439
CATEGORY = "JoyCaption"
429440
DESCRIPTION = "Runs JoyCaption on the input image to generate a caption. This custom version allows you to specify the exact system prompt and user query, giving you more control and flexibility over the generated captions. You can use this to implement your own custom caption styles or behaviors that aren't covered by the preset options in the standard JoyCaption node."
430441

431-
def __init__(self):
432-
self.predictor = None
433-
434442
def generate(
435443
self,
436-
image,
437-
memory_mode,
444+
model: JoyCaptionPredictor,
445+
image: torch.Tensor,
438446
system_prompt: str,
439447
user_query: str,
440448
max_new_tokens: int,
441449
temperature: float,
442450
top_p: float,
443451
top_k: int,
444-
keep_loaded: bool,
445452
):
446453
if image.shape[0] != 1:
447454
return ("Error: batch size greater than 1 is not supported.",)
448455

449-
if self.predictor is None or self.predictor.memory_mode != memory_mode:
450-
if self.predictor is not None:
451-
self.predictor.unload()
452-
del self.predictor
453-
self.predictor = None
454-
455-
try:
456-
self.predictor = JoyCaptionPredictor("fancyfeast/llama-joycaption-beta-one-hf-llava", memory_mode)
457-
except Exception as e:
458-
return (f"Error loading model: {e}",)
459-
460456
# This is a bit silly. We get the image as a tensor, and we could just use that directly (just need to resize and adjust the normalization).
461457
# But JoyCaption was trained on images that were resized using lanczos, which I think PyTorch doesn't support.
462458
# Just to be safe, we'll convert the image to a PIL image and let the processor handle it correctly.
463-
pil_image = ToPILImage()(image[0].permute(2, 0, 1))
459+
pil_image = TVF.to_pil_image(image[0].permute(2, 0, 1))
464460
try:
465-
response = self.predictor.generate(
461+
response = model.generate(
466462
image=pil_image,
467463
system=system_prompt,
468464
prompt=user_query,
@@ -472,7 +468,7 @@ def generate(
472468
top_k=top_k,
473469
)
474470
finally:
475-
self.predictor.cleanup_after_inference(keep_loaded=keep_loaded)
471+
model.cleanup_after_inference()
476472

477473
return (response,)
478474

0 commit comments

Comments
 (0)