Skip to content

Commit 30be4bd

Browse files
committed
Add proper seed based cache so so that when set to fixed seed and no changes made, it will not send additional requests
1 parent c7cffb7 commit 30be4bd

File tree

1 file changed

+80
-2
lines changed

1 file changed

+80
-2
lines changed

gemini_nodes.py

Lines changed: 80 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,14 +159,36 @@ def configure(self, api_key, api_version, vertexai, vertexai_project, vertexai_l
159159

160160

161161
class SSL_GeminiTextPrompt(ComfyNodeABC):
162+
def __init__(self):
163+
super().__init__()
164+
# Initialize cache state
165+
self.last_input_seed = None
166+
self.last_actual_seed = None
167+
self.last_text_output = None
168+
self.last_image_tensor = None
169+
self.last_config = None
170+
self.last_prompt = None
171+
self.last_system_instruction = None
172+
self.last_model = None
173+
self.last_temperature = None
174+
self.last_top_p = None
175+
self.last_top_k = None
176+
self.last_max_output_tokens = None
177+
self.last_include_images = None
178+
self.last_aspect_ratio = None
179+
self.last_bypass_mode = None
180+
self.last_thinking_budget = None
181+
self.last_input_image = None
182+
self.last_input_image_2 = None
183+
162184
@classmethod
163185
def INPUT_TYPES(cls) -> InputTypeDict:
164186
return {
165187
"required": {
166188
"config": ("GEMINI_CONFIG",),
167189
"prompt": (IO.STRING, {"multiline": True}),
168190
"system_instruction": (IO.STRING, {"default": "You are a helpful AI assistant.", "multiline": True}),
169-
"model": (["gemini-1.0-pro", "gemini-exp-1206", "gemini-1.5-flash", "gemini-1.5-pro", "gemini-2.0-flash", "gemini-2.0-flash-lite-001", "gemini-2.0-flash-exp", "gemini-2.0-pro", "gemini-2.0-flash-live", "gemini-2.5-pro", "gemini-2.5-pro-preview-05-06", "gemini-2.5-flash", "gemini-2.5-flash-preview-04-17", "gemini-2.5-flash-lite-preview-06-17", "gemini-2.5-flash-image-preview"], {"default": "gemini-2.0-flash"}),
191+
"model": (["gemini-1.0-pro", "gemini-exp-1206", "gemini-1.5-flash", "gemini-1.5-pro", "gemini-2.0-flash", "gemini-2.0-flash-lite-001", "gemini-2.0-flash-exp", "gemini-2.0-pro", "gemini-2.0-pro-exp", "gemini-2.0-flash-thinking-exp", "gemini-2.0-flash-thinking-exp-01-21", "gemini-2.5-pro", "gemini-2.5-pro-preview-05-06", "gemini-2.5-flash", "gemini-2.5-flash-preview-04-17", "gemini-2.5-flash-lite-preview-06-17", "gemini-2.5-flash-image-preview"], {"default": "gemini-2.0-flash"}),
170192
"temperature": (IO.FLOAT, {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
171193
"top_p": (IO.FLOAT, {"default": 0.95, "min": 0.0, "max": 1.0, "step": 0.01}),
172194
"top_k": (IO.INT, {"default": 40, "min": 1, "max": 100, "step": 1}),
@@ -253,6 +275,40 @@ def generate_empty_image(self, width=64, height=64):
253275
def generate(self, config, prompt, system_instruction, model, temperature, top_p, top_k, max_output_tokens, include_images,
254276
aspect_ratio, bypass_mode, thinking_budget, input_image=None, input_image_2=None,
255277
use_proxy=False, proxy_host="127.0.0.1", proxy_port=7890, use_seed=False, seed=0):
278+
279+
# Helper for comparing optional tensors
280+
def compare_tensors(t1, t2):
281+
if t1 is None and t2 is None:
282+
return True
283+
if t1 is not None and t2 is not None:
284+
return torch.equal(t1, t2)
285+
return False
286+
287+
# Comprehensive cache check
288+
is_cached = (
289+
use_seed and
290+
self.last_image_tensor is not None and
291+
self.last_input_seed == seed and
292+
self.last_config == config and
293+
self.last_prompt == prompt and
294+
self.last_system_instruction == system_instruction and
295+
self.last_model == model and
296+
self.last_temperature == temperature and
297+
self.last_top_p == top_p and
298+
self.last_top_k == top_k and
299+
self.last_max_output_tokens == max_output_tokens and
300+
self.last_include_images == include_images and
301+
self.last_aspect_ratio == aspect_ratio and
302+
self.last_bypass_mode == bypass_mode and
303+
self.last_thinking_budget == thinking_budget and
304+
compare_tensors(self.last_input_image, input_image) and
305+
compare_tensors(self.last_input_image_2, input_image_2)
306+
)
307+
308+
if is_cached:
309+
print(f"[INFO] All inputs match the previous run with seed ({seed}). Returning cached result.")
310+
return (self.last_text_output, self.last_image_tensor, self.last_actual_seed)
311+
256312
original_http_proxy = os.environ.get('HTTP_PROXY')
257313
original_https_proxy = os.environ.get('HTTPS_PROXY')
258314
original_http_proxy_lower = os.environ.get('http_proxy')
@@ -670,7 +726,29 @@ def api_call():
670726
if image_tensor is None:
671727
image_tensor = self.generate_empty_image()
672728

673-
return (text_output, image_tensor, actual_seed if actual_seed is not None else 0)
729+
final_actual_seed = actual_seed if actual_seed is not None else 0
730+
if use_seed:
731+
# Update cache with all current inputs and outputs
732+
self.last_input_seed = seed
733+
self.last_actual_seed = final_actual_seed
734+
self.last_text_output = text_output
735+
self.last_image_tensor = image_tensor
736+
self.last_config = config
737+
self.last_prompt = prompt
738+
self.last_system_instruction = system_instruction
739+
self.last_model = model
740+
self.last_temperature = temperature
741+
self.last_top_p = top_p
742+
self.last_top_k = top_k
743+
self.last_max_output_tokens = max_output_tokens
744+
self.last_include_images = include_images
745+
self.last_aspect_ratio = aspect_ratio
746+
self.last_bypass_mode = bypass_mode
747+
self.last_thinking_budget = thinking_budget
748+
self.last_input_image = input_image.clone() if input_image is not None else None
749+
self.last_input_image_2 = input_image_2.clone() if input_image_2 is not None else None
750+
751+
return (text_output, image_tensor, final_actual_seed)
674752
finally:
675753
if original_http_proxy:
676754
os.environ['HTTP_PROXY'] = original_http_proxy

0 commit comments

Comments
 (0)