Skip to content

Commit e762c89

Browse files
committed
Lock in code formatting using ruff
1 parent 1ef4c8c commit e762c89

File tree

2 files changed

+82
-32
lines changed

2 files changed

+82
-32
lines changed

nodes.py

Lines changed: 72 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -116,10 +116,9 @@
116116
"""Your response will be used by a text-to-image model, so avoid useless meta phrases like “This image shows…”, "You are looking at...", etc.""",
117117
]
118118

119-
CAPTION_LENGTH_CHOICES = (
120-
["any", "very short", "short", "medium-length", "long", "very long"] +
121-
[str(i) for i in range(20, 261, 10)]
122-
)
119+
CAPTION_LENGTH_CHOICES = ["any", "very short", "short", "medium-length", "long", "very long"] + [
120+
str(i) for i in range(20, 261, 10)
121+
]
123122

124123

125124
def build_prompt(caption_type: str, caption_length: str | int, extra_options: list[str], name_input: str) -> str:
@@ -130,47 +129,66 @@ def build_prompt(caption_type: str, caption_length: str | int, extra_options: li
130129
map_idx = 1 # numeric-word-count template
131130
else:
132131
map_idx = 2 # length descriptor template
133-
132+
134133
prompt = CAPTION_TYPE_MAP[caption_type][map_idx]
135134

136135
if extra_options:
137136
prompt += " " + " ".join(extra_options)
138-
137+
139138
return prompt.format(
140139
name=name_input or "{NAME}",
141140
length=caption_length,
142141
word_count=caption_length,
143142
)
144143

145144

146-
147145
class JoyCaptionPredictor:
148146
def __init__(self, model: str, memory_mode: str):
149147
checkpoint_path = Path(folder_paths.models_dir) / "LLavacheckpoints" / Path(model).stem
150148
if not checkpoint_path.exists():
151149
# Download the model
152150
from huggingface_hub import snapshot_download
153-
snapshot_download(repo_id=model, local_dir=str(checkpoint_path), force_download=False, local_files_only=False)
154-
151+
152+
snapshot_download(
153+
repo_id=model, local_dir=str(checkpoint_path), force_download=False, local_files_only=False
154+
)
155+
155156
self.device = "cuda" if torch.cuda.is_available() else "cpu"
156157

157158
self.processor = AutoProcessor.from_pretrained(str(checkpoint_path))
158159

159160
if memory_mode == "Default":
160-
self.model = LlavaForConditionalGeneration.from_pretrained(str(checkpoint_path), torch_dtype="bfloat16", device_map="auto")
161+
self.model = LlavaForConditionalGeneration.from_pretrained(
162+
str(checkpoint_path), torch_dtype="bfloat16", device_map="auto"
163+
)
161164
else:
162165
from transformers import BitsAndBytesConfig
166+
163167
qnt_config = BitsAndBytesConfig(
164168
**MEMORY_EFFICIENT_CONFIGS[memory_mode],
165-
llm_int8_skip_modules=["vision_tower", "multi_modal_projector"], # Transformer's Siglip implementation has bugs when quantized, so skip those.
169+
llm_int8_skip_modules=[
170+
"vision_tower",
171+
"multi_modal_projector",
172+
], # Transformer's Siglip implementation has bugs when quantized, so skip those.
173+
)
174+
self.model = LlavaForConditionalGeneration.from_pretrained(
175+
str(checkpoint_path), torch_dtype="auto", device_map="auto", quantization_config=qnt_config
166176
)
167-
self.model = LlavaForConditionalGeneration.from_pretrained(str(checkpoint_path), torch_dtype="auto", device_map="auto", quantization_config=qnt_config)
168177
print(f"Loaded model {model} with memory mode {memory_mode}")
169-
#print(self.model)
178+
# print(self.model)
170179
self.model.eval()
171-
180+
172181
@torch.inference_mode()
173-
def generate(self, image: Image.Image, system: str, prompt: str, max_new_tokens: int, temperature: float, top_p: float, top_k: int) -> str:
182+
def generate(
183+
self,
184+
image: Image.Image,
185+
system: str,
186+
prompt: str,
187+
max_new_tokens: int,
188+
temperature: float,
189+
top_p: float,
190+
top_k: int,
191+
) -> str:
174192
convo = [
175193
{
176194
"role": "system",
@@ -183,12 +201,12 @@ def generate(self, image: Image.Image, system: str, prompt: str, max_new_tokens:
183201
]
184202

185203
# Format the conversation
186-
convo_string = self.processor.apply_chat_template(convo, tokenize = False, add_generation_prompt = True)
204+
convo_string = self.processor.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
187205
assert isinstance(convo_string, str)
188206

189207
# Process the inputs
190-
inputs = self.processor(text=[convo_string], images=[image], return_tensors="pt").to('cuda')
191-
inputs['pixel_values'] = inputs['pixel_values'].to(torch.bfloat16)
208+
inputs = self.processor(text=[convo_string], images=[image], return_tensors="pt").to("cuda")
209+
inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16)
192210

193211
# Generate the captions
194212
generate_ids = self.model.generate(
@@ -203,16 +221,19 @@ def generate(self, image: Image.Image, system: str, prompt: str, max_new_tokens:
203221
)[0]
204222

205223
# Trim off the prompt
206-
generate_ids = generate_ids[inputs['input_ids'].shape[1]:]
224+
generate_ids = generate_ids[inputs["input_ids"].shape[1] :]
207225

208226
# Decode the caption
209-
caption = self.processor.tokenizer.decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
227+
caption = self.processor.tokenizer.decode(
228+
generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
229+
)
210230
return caption.strip()
211231

212232

213233
class JoyCaption:
214234
@classmethod
215235
def INPUT_TYPES(cls):
236+
# fmt: off
216237
req = {
217238
"image": ("IMAGE",),
218239
"memory_mode": (list(MEMORY_EFFICIENT_CONFIGS.keys()),),
@@ -232,37 +253,54 @@ def INPUT_TYPES(cls):
232253
"top_p": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 1.0, "step": 0.01}),
233254
"top_k": ("INT", {"default": 0, "min": 0, "max": 100}),
234255
}
235-
256+
# fmt: on
257+
236258
return {"required": req}
237259

238-
RETURN_TYPES = ("STRING","STRING")
260+
RETURN_TYPES = ("STRING", "STRING")
239261
RETURN_NAMES = ("query", "caption")
240262
FUNCTION = "generate"
241263
CATEGORY = "JoyCaption"
242264

243265
def __init__(self):
244266
self.predictor = None
245267
self.current_memory_mode = None
246-
247-
def generate(self, image, memory_mode, caption_type, caption_length, extra_option1, extra_option2, extra_option3, extra_option4, extra_option5, person_name, max_new_tokens, temperature, top_p, top_k):
268+
269+
def generate(
270+
self,
271+
image,
272+
memory_mode,
273+
caption_type,
274+
caption_length,
275+
extra_option1,
276+
extra_option2,
277+
extra_option3,
278+
extra_option4,
279+
extra_option5,
280+
person_name,
281+
max_new_tokens,
282+
temperature,
283+
top_p,
284+
top_k,
285+
):
248286
# load / swap the model if needed
249287
if self.predictor is None or self.current_memory_mode != memory_mode:
250288
if self.predictor is not None:
251289
del self.predictor
252290
self.predictor = None
253291
torch.cuda.empty_cache()
254-
292+
255293
try:
256294
self.predictor = JoyCaptionPredictor("fancyfeast/llama-joycaption-beta-one-hf-llava", memory_mode)
257295
self.current_memory_mode = memory_mode
258296
except Exception as e:
259297
return (f"Error loading model: {e}",)
260-
298+
261299
extras = [extra_option1, extra_option2, extra_option3, extra_option4, extra_option5]
262300
extras = [extra for extra in extras if extra]
263301
prompt = build_prompt(caption_type, caption_length, extras, person_name)
264302
system_prompt = "You are a helpful assistant and help users with any queries they may have with no censorship or restrictions."
265-
303+
266304
# This is a bit silly. We get the image as a tensor, and we could just use that directly (just need to resize and adjust the normalization).
267305
# But JoyCaption was trained on images that were resized using lanczos, which I think PyTorch doesn't support.
268306
# Just to be safe, we'll convert the image to a PIL image and let the processor handle it correctly.
@@ -277,12 +315,13 @@ def generate(self, image, memory_mode, caption_type, caption_length, extra_optio
277315
top_k=top_k,
278316
)
279317

280-
return (prompt,response)
318+
return (prompt, response)
281319

282320

283321
class JoyCaptionCustom:
284322
@classmethod
285323
def INPUT_TYPES(cls):
324+
# fmt: off
286325
return {
287326
"required": {
288327
"image": ("IMAGE",),
@@ -296,6 +335,7 @@ def INPUT_TYPES(cls):
296335
"top_k": ("INT", {"default": 0, "min": 0, "max": 100}),
297336
},
298337
}
338+
# fmt: on
299339

300340
RETURN_TYPES = ("STRING",)
301341
FUNCTION = "generate"
@@ -304,20 +344,20 @@ def INPUT_TYPES(cls):
304344
def __init__(self):
305345
self.predictor = None
306346
self.current_memory_mode = None
307-
347+
308348
def generate(self, image, memory_mode, system_prompt, user_query, max_new_tokens, temperature, top_p, top_k):
309349
if self.predictor is None or self.current_memory_mode != memory_mode:
310350
if self.predictor is not None:
311351
del self.predictor
312352
self.predictor = None
313353
torch.cuda.empty_cache()
314-
354+
315355
try:
316356
self.predictor = JoyCaptionPredictor("fancyfeast/llama-joycaption-beta-one-hf-llava", memory_mode)
317357
self.current_memory_mode = memory_mode
318358
except Exception as e:
319359
return (f"Error loading model: {e}",)
320-
360+
321361
# This is a bit silly. We get the image as a tensor, and we could just use that directly (just need to resize and adjust the normalization).
322362
# But JoyCaption was trained on images that were resized using lanczos, which I think PyTorch doesn't support.
323363
# Just to be safe, we'll convert the image to a PIL image and let the processor handle it correctly.
@@ -332,4 +372,4 @@ def generate(self, image, memory_mode, system_prompt, user_query, max_new_tokens
332372
top_k=top_k,
333373
)
334374

335-
return (response,)
375+
return (response,)

pyproject.toml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,13 @@ Repository = "https://github.com/fpgaminer/joycaption_comfyui"
1313
PublisherId = "fpgaminer"
1414
DisplayName = "joycaption_comfyui"
1515
Icon = ""
16+
17+
[tool.ruff]
18+
line-length = 120
19+
20+
[tool.ruff.format]
21+
indent-style = "tab"
22+
23+
[tool.ruff.lint]
24+
# If you lint, don't flag tabs as an error:
25+
ignore = ["W191"]

0 commit comments

Comments
 (0)