Skip to content

Commit 6a1b383

Browse files
committed
added the GeminiPromptEnhance node.
1 parent cae3c6d commit 6a1b383

File tree

4 files changed

+216
-2
lines changed

4 files changed

+216
-2
lines changed

README.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,18 @@ This node, ColorCorrectOfUtils, is an extension of the original [ColorCorrect](h
3939
## ModifyTextGender
4040
This node adjusts the text to describe the gender based on the input. If the gender input is 'M', the text will be adjusted to describe as male; if the gender input is 'F', it will be adjusted to describe as female.
4141

42+
## GeminiPromptEnhance
43+
This node is designed to enhance the text description of the image, using the latest Gemini 2.0 Flash model. It can add quality descriptors, lighting descriptions, scene descriptions, and skin descriptions to the text. According to the gender input, it can also modify gender-related content.
44+
45+
To use this node, you must get a free API key from Google AI Studio:
46+
- Visit [Google AI Studio](https://aistudio.google.com/prompts/new_chat)
47+
- Log in with your Google account
48+
- Click on "Get API key" or go to settings
49+
- Create a new API key
50+
- Copy the API key for use in the node's input or gemini_config.json
51+
52+
This code is based on https://github.com/ShmuelRonen/ComfyUI-Gemini_Flash_2.0_Exp, with new features added. Thanks to @ShmuelRonen.
53+
4254
## GenderControlOutput
4355
This node determines the output based on the input gender. If the gender input is 'M', it will output male-specific text, float, and integer values. If the gender input is 'F', it will output female-specific text, float, and integer values.
4456

py/node_gemini_enhance_prompte.py

Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
# this code is original from https://github.com/ShmuelRonen/ComfyUI-Gemini_Flash_2.0_Exp, added cache and gender support
2+
import os
3+
import json
4+
import google.generativeai as genai
5+
from contextlib import contextmanager
6+
from collections import OrderedDict
7+
import folder_paths
8+
import logging
9+
import yaml
10+
logger = logging.getLogger(__name__)
11+
12+
# Directory (under the ComfyUI base path) where this module keeps its
# configuration and cache files.
config_dir = os.path.join(folder_paths.base_path, "config")
# exist_ok avoids the check-then-create race of `if not exists: makedirs`.
os.makedirs(config_dir, exist_ok=True)
15+
16+
17+
def get_config():
    """Load the node configuration from config/gemini_config.yml.

    Returns:
        dict: the parsed configuration, or an empty dict when the file is
        missing, unreadable, empty, or not valid YAML.
    """
    config_path = os.path.join(config_dir, 'gemini_config.yml')
    try:
        with open(config_path, 'r') as f:
            # safe_load is sufficient here (the file holds plain key/value
            # strings) and avoids constructing arbitrary Python objects.
            config = yaml.safe_load(f)
        # An empty file parses to None; callers expect a dict with .get().
        return config or {}
    except (OSError, yaml.YAMLError) as e:
        # Narrow except: a missing/broken config simply means "no defaults".
        logger.debug("could not load gemini config %s: %s", config_path, e)
        return {}
25+
26+
def save_config(config):
    """Persist the given configuration dict to config/gemini_config.yml."""
    target_path = os.path.join(config_dir, 'gemini_config.yml')
    with open(target_path, 'w') as handle:
        yaml.dump(config, handle, indent=4)
30+
31+
@contextmanager
def temporary_env_var(key: str, new_value):
    """Temporarily set (or unset, when new_value is None) an environment
    variable for the duration of the `with` block, restoring the previous
    state afterwards even if the body raises."""
    previous = os.environ.get(key)
    if new_value is None:
        # Unset for the duration of the block (no-op if absent).
        os.environ.pop(key, None)
    else:
        os.environ[key] = new_value
    try:
        yield
    finally:
        # Restore the pre-existing state exactly.
        if previous is None:
            os.environ.pop(key, None)
        else:
            os.environ[key] = previous
45+
46+
class LRUCache(OrderedDict):
    """A fixed-capacity mapping that evicts the least-recently-used entry.

    Both successful lookups (`get`) and insertions (`put`) mark the key as
    most recently used.
    """

    def __init__(self, capacity):
        super().__init__()
        self.capacity = capacity

    def get(self, key):
        """Return the cached value for *key*, or None when absent."""
        try:
            self.move_to_end(key)  # raises KeyError when key is missing
        except KeyError:
            return None
        return self[key]

    def put(self, key, value):
        """Insert/update *key*, evicting the oldest entry past capacity."""
        if key in self:
            self.move_to_end(key)
        self[key] = value
        while len(self) > self.capacity:
            # Oldest (least recently used) entry sits at the left end.
            self.popitem(last=False)
63+
64+
class GeminiPromptEnhance:
    """ComfyUI node that enhances a text prompt via the Gemini 2.0 flash model.

    Adds quality/lighting/scene/skin descriptors to the input text and can
    rewrite gender references according to the `gender` input. Successful
    results are cached on disk (LRU) to avoid repeated API calls for the
    same inputs.
    """

    def __init__(self, api_key=None, proxy=None):
        # Explicit arguments win over values stored in gemini_config.yml.
        config = get_config()
        self.api_key = api_key or config.get("GEMINI_API_KEY")
        self.proxy = proxy or config.get("PROXY")
        self.cache_size = 500  # maximum number of cached prompt results
        self.cache_file = os.path.join(config_dir, 'prompt_cache_gemini.yml')
        self.cache = LRUCache(self.cache_size)
        self.load_cache()
        if self.api_key is not None:
            self.configure_genai()

    def load_cache(self):
        """Populate the in-memory LRU cache from the on-disk YAML file."""
        try:
            if os.path.exists(self.cache_file):
                with open(self.cache_file, 'r', encoding='utf-8') as f:
                    # safe_load: the cache holds plain strings only.
                    cache_data = yaml.safe_load(f)
                # An empty/blank file parses to None -> treat as no cache
                # instead of crashing on .items().
                for k, v in (cache_data or {}).items():
                    self.cache.put(k, v)
        except Exception as e:
            logger.error(f"加载缓存出错: {str(e)}")
            # Fall back to an empty cache rather than failing node creation.
            self.cache = LRUCache(self.cache_size)

    def save_cache(self):
        """Write the current cache contents back to disk (best effort)."""
        try:
            with open(self.cache_file, 'w', encoding='utf-8') as f:
                # allow_unicode keeps non-ASCII prompt text human-readable
                # in the YAML file instead of escaped byte sequences.
                yaml.dump(dict(self.cache), f, indent=4, allow_unicode=True)
        except Exception as e:
            logger.error(f"保存缓存出错: {str(e)}")

    def configure_genai(self):
        """Configure the google-generativeai client with the current key."""
        genai.configure(api_key=self.api_key, transport='rest')

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the ComfyUI input sockets/widgets for this node."""
        default_prompt = "Edit and enhance the text description of the image. \nAdd quality descriptors, like 'A high-quality photo, an 8K photo.' \nAdd lighting descriptions based on the scene, like 'The lighting is natural and bright, casting soft shadows.' \nAdd scene descriptions according to the context, like 'The overall mood is serene and peaceful.' \nIf a person is in the scene, include a description of the skin, such as 'natural skin tones and ensure the skin appears realistic with clear, fine details.' \n\nOnly output the result of the text, no others.\nthe text is:"

        return {
            "required": {
                "prompt": ("STRING", {"default": default_prompt, "multiline": True}),
            },
            "optional": {
                "text_input": ("STRING", {"default": "", "multiline": True}),
                "api_key": ("STRING", {"default": ""}),
                "proxy": ("STRING", {"default": ""}),
                "max_output_tokens": ("INT", {"default": 8192, "min": 1, "max": 8192}),
                "temperature": ("FLOAT", {"default": 0.4, "min": 0.0, "max": 1.0, "step": 0.1}),
                "gender": (["", "M", "F"], {"default": ""}),
                "enabled": ("BOOLEAN", {"default": True}),
            }
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("generated_content",)
    FUNCTION = "generate_content"
    CATEGORY = "utils/text"

    def prepare_content(self, prompt, text_input, gender=""):
        """Build the request payload: optional gender instruction + prompt
        + user text, as a single text part.

        Returns:
            list[dict]: one-element list of {"text": ...} parts.
        """
        if gender == "M":
            prompt = "edit and enhance the text content according to male gender. if there is a female, must change the text to describe as male.\n" + prompt
        elif gender == "F":
            prompt = "edit and enhance the text content according to female gender. if there is a male, must change the text to describe as female.\n" + prompt

        text_content = prompt if not text_input else f"{prompt}\n{text_input}"
        return [{"text": text_content}]

    def generate_content(self, prompt, text_input=None, api_key="", proxy="",
                         max_output_tokens=8192, temperature=0.4, gender="", enabled=True):
        """Run the enhancement, consulting/updating the disk-backed cache.

        Returns a one-tuple with the generated text; on API failure the
        tuple contains an "Error: ..." string instead of raising, so the
        workflow keeps running.
        """
        if not enabled:
            # Pass the input through untouched when the node is disabled.
            return (text_input,)

        # Cache key must include the prompt as well: with only
        # text_input/gender, editing the prompt would return stale results.
        cache_key = f"{prompt}_{text_input or ''}_{gender}"

        cached_result = self.cache.get(cache_key)
        if cached_result is not None:
            return (cached_result,)

        # Set all safety settings to block_none by default
        safety_settings = [
            {"category": "harassment", "threshold": "NONE"},
            {"category": "hate_speech", "threshold": "NONE"},
            {"category": "sexually_explicit", "threshold": "NONE"},
            {"category": "dangerous_content", "threshold": "NONE"},
            {"category": "civic", "threshold": "NONE"}
        ]

        # Only update API key if explicitly provided in the node;
        # persist it so future sessions pick it up from the config file.
        if api_key.strip():
            self.api_key = api_key
            save_config({"GEMINI_API_KEY": self.api_key, "PROXY": self.proxy})
            self.configure_genai()

        # Only update proxy if explicitly provided in the node
        if proxy.strip():
            self.proxy = proxy
            save_config({"GEMINI_API_KEY": self.api_key, "PROXY": self.proxy})

        if not self.api_key:
            # The config file is gemini_config.yml (not config.json).
            raise ValueError("API key not found in gemini_config.yml or node input")

        model_name = 'models/gemini-2.0-flash-exp'
        model = genai.GenerativeModel(model_name)

        # Apply fixed safety settings to the model
        model.safety_settings = safety_settings

        generation_config = genai.types.GenerationConfig(
            max_output_tokens=max_output_tokens,
            temperature=temperature
        )

        # Route the REST call through the configured proxy (if any) for the
        # duration of the request only.
        with temporary_env_var('HTTP_PROXY', self.proxy), temporary_env_var('HTTPS_PROXY', self.proxy):
            try:
                content_parts = self.prepare_content(prompt, text_input, gender)
                response = model.generate_content(content_parts, generation_config=generation_config)
                generated_content = response.text

                # Cache only successful results, never error strings.
                self.cache.put(cache_key, generated_content)
                self.save_cache()

            except Exception as e:
                generated_content = f"Error: {str(e)}"

        return (generated_content,)
192+
193+
# ComfyUI registration: maps the node's internal name to its class.
NODE_CLASS_MAPPINGS = {
    "GeminiPromptEnhance": GeminiPromptEnhance,
}

# Human-readable name shown in the ComfyUI node picker.
NODE_DISPLAY_NAME_MAPPINGS = {
    "GeminiPromptEnhance": "Gemini prompt enhance",
}

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[project]
22
name = "comfyui-utils-nodes"
33
description = "Nodes:LoadImageWithSwitch, ImageBatchOneOrMore, ModifyTextGender, GenderControlOutput, ImageCompositeMaskedWithSwitch, ImageCompositeMaskedOneByOne, ColorCorrectOfUtils, SplitMask, MaskFastGrow, CheckpointLoaderSimpleWithSwitch, ImageResizeTo8x, MatchImageRatioToPreset, UpscaleImageWithModelIfNeed, MaskFromFaceModel, MaskCoverFourCorners, DetectorForNSFW, DeepfaceAnalyzeFaceAttributes etc."
4-
version = "1.2.6"
4+
version = "1.2.7"
55
license = { file = "LICENSE" }
66
dependencies = []
77

requirements.txt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,7 @@ onnxruntime>=1.19.2
77
# DeepfaceAnalyzeFaceAttributes
88
deepface==0.0.93
99
ultralytics
10-
tf-keras
10+
tf-keras
11+
12+
# GeminiPromptEnhance node
13+
google-generativeai>0.4.1

0 commit comments

Comments
 (0)