Skip to content

Commit 71107a8

Browse files
author
Ubuntu
committed
Refactor the ComfyUI API `io` import to the alias `cio` so it no longer shadows the standard library `io.BytesIO`
1 parent a22e3bd commit 71107a8

File tree

1 file changed

+59
-58
lines changed

1 file changed

+59
-58
lines changed

gemini_nodes.py

Lines changed: 59 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@
66
import numpy as np
77
import cv2
88
from PIL import Image
9-
import io as stdlib_io
9+
import io
1010
import folder_paths # type: ignore[reportMissingImports]
11-
from comfy_api.latest import ComfyExtension, io, ui # type: ignore[reportMissingImports]
11+
from comfy_api.latest import ComfyExtension, ui # type: ignore[reportMissingImports]
12+
from comfy_api.latest import io as cio
1213
from google import genai
1314
from google.genai import types
1415
import time
@@ -44,26 +45,26 @@ def check_and_install_dependencies():
4445
except Exception as e:
4546
print(f"[WARNING] Error checking dependencies: {str(e)}")
4647

47-
class GetKeyAPI(io.ComfyNode):
48+
class GetKeyAPI(cio.ComfyNode):
4849
@classmethod
49-
def define_schema(cls) -> io.Schema:
50-
return io.Schema(
50+
def define_schema(cls) -> cio.Schema:
51+
return cio.Schema(
5152
node_id="GetKeyAPI",
5253
display_name="Get API Key from JSON",
5354
category="utils/api_keys",
5455
inputs=[
55-
io.String.Input("json_path", default="./input/apikeys.json", multiline=False, tooltip="Path to a .json file with simple top level structure with name as key and api-key as value. See example in custom node folder."),
56-
io.Combo.Input("key_id_method", options=["custom", "random_rotate", "increment_rotate"], default="custom", tooltip="custom sets api-key to the api-key with the name set in the key_id widget. random_rotate randomly switches between keys if multiple in the .json and increment_rotate does it in order from first to last, then repeats."),
57-
io.Int.Input("rotation_interval", default=0, min=0, tooltip="how many steps to jump when doing rotate."),
58-
io.String.Input("key_id", default="placeholder", multiline=False, optional=True, tooltip="Put name of key in the .json here if using custom in key_id_method."),
56+
cio.String.Input("json_path", default="./input/apikeys.json", multiline=False, tooltip="Path to a .json file with simple top level structure with name as key and api-key as value. See example in custom node folder."),
57+
cio.Combo.Input("key_id_method", options=["custom", "random_rotate", "increment_rotate"], default="custom", tooltip="custom sets api-key to the api-key with the name set in the key_id widget. random_rotate randomly switches between keys if multiple in the .json and increment_rotate does it in order from first to last, then repeats."),
58+
cio.Int.Input("rotation_interval", default=0, min=0, tooltip="how many steps to jump when doing rotate."),
59+
cio.String.Input("key_id", default="placeholder", multiline=False, optional=True, tooltip="Put name of key in the .json here if using custom in key_id_method."),
5960
],
6061
outputs=[
61-
io.String.Output("API_KEY")
62+
cio.String.Output("API_KEY")
6263
]
6364
)
6465

6566
@classmethod
66-
def execute(cls, json_path: str, key_id_method: str, rotation_interval: int, key_id: str | None = "placeholder") -> io.NodeOutput:
67+
def execute(cls, json_path: str, key_id_method: str, rotation_interval: int, key_id: str | None = "placeholder") -> cio.NodeOutput:
6768
api_keys_data = None
6869
absolute_json_path = os.path.abspath(json_path)
6970

@@ -116,40 +117,40 @@ def execute(cls, json_path: str, key_id_method: str, rotation_interval: int, key
116117
raise ValueError(f"RotateKeyAPI Error: Retrieved value for selected key is not a valid string. Value: {selected_key_value}")
117118

118119
print(f"RotateKeyAPI: Successfully retrieved API key using method '{key_id_method}'.")
119-
return io.NodeOutput(selected_key_value)
120+
return cio.NodeOutput(selected_key_value)
120121

121122

122123

123-
class SSL_GeminiAPIKeyConfig(io.ComfyNode):
124-
GemConfig = io.Custom("GEMINI_CONFIG")
124+
class SSL_GeminiAPIKeyConfig(cio.ComfyNode):
125+
GemConfig = cio.Custom("GEMINI_CONFIG")
125126

126127
@classmethod
127-
def define_schema(cls) -> io.Schema:
128-
return io.Schema(
128+
def define_schema(cls) -> cio.Schema:
129+
return cio.Schema(
129130
node_id="SSL_GeminiAPIKeyConfig",
130131
display_name="Configure Gemini API Key",
131132
category="API/Gemini",
132133
inputs=[
133-
io.String.Input("api_key", multiline=False),
134-
io.Combo.Input("api_version", options=["v1", "v1alpha", "v1beta", "v2beta"], default="v1alpha"),
135-
io.Boolean.Input("vertexai", default=False),
136-
io.String.Input("vertexai_project", default="placeholder", optional=True),
137-
io.String.Input("vertexai_location", default="placeholder", optional=True),
134+
cio.String.Input("api_key", multiline=False),
135+
cio.Combo.Input("api_version", options=["v1", "v1alpha", "v1beta", "v2beta"], default="v1alpha"),
136+
cio.Boolean.Input("vertexai", default=False),
137+
cio.String.Input("vertexai_project", default="placeholder", optional=True),
138+
cio.String.Input("vertexai_location", default="placeholder", optional=True),
138139
],
139140
outputs=[
140141
cls.GemConfig.Output("config")
141142
]
142143
)
143144

144145
@classmethod
145-
def execute(cls, api_key: str, api_version: str, vertexai: bool, vertexai_project: str | None = "placeholder", vertexai_location: str | None = "placeholder") -> io.NodeOutput:
146+
def execute(cls, api_key: str, api_version: str, vertexai: bool, vertexai_project: str | None = "placeholder", vertexai_location: str | None = "placeholder") -> cio.NodeOutput:
146147
config = {"api_key": api_key, "api_version": api_version, "vertexai": vertexai, "vertexai_project": vertexai_project, "vertexai_location": vertexai_location}
147-
return io.NodeOutput(config)
148+
return cio.NodeOutput(config)
148149

149150

150151

151-
class SSL_GeminiTextPrompt(io.ComfyNode):
152-
GemConfig = io.Custom("GEMINI_CONFIG")
152+
class SSL_GeminiTextPrompt(cio.ComfyNode):
153+
GemConfig = cio.Custom("GEMINI_CONFIG")
153154
_cache: dict = {}
154155

155156
# Define model lists centrally to ensure consistency between cache logic and execution logic
@@ -169,40 +170,40 @@ class SSL_GeminiTextPrompt(io.ComfyNode):
169170
]
170171

171172
@classmethod
172-
def define_schema(cls) -> io.Schema:
173-
return io.Schema(
173+
def define_schema(cls) -> cio.Schema:
174+
return cio.Schema(
174175
node_id="SSL_GeminiTextPrompt",
175176
display_name="Expanded Gemini Text/Image",
176177
category="API/Gemini",
177178
inputs=[
178179
cls.GemConfig.Input("config"),
179-
io.String.Input("prompt", multiline=True),
180-
io.String.Input("system_instruction", default="You are a helpful AI assistant.", multiline=True),
181-
io.Combo.Input("model", options=["learnlm-2.0-flash-experimental", "gemini-exp-1206", "gemini-2.0-flash", "gemini-2.0-flash-lite-001", "gemini-2.0-flash-exp", "gemini-2.0-flash-thinking-exp", "gemini-2.0-flash-thinking-exp-01-21", "gemini-2.0-flash-thinking-exp-1219", "gemini-2.5-pro", "gemini-2.5-pro-preview-05-06", "gemini-2.5-flash", "gemini-2.5-flash-preview-09-2025", "gemini-3-pro-preview", "gemini-2.5-flash-image-preview"], default="gemini-2.0-flash"),
182-
io.Float.Input("temperature", default=1.0, min=0.0, max=1.0, step=0.01),
183-
io.Float.Input("top_p", default=0.95, min=0.0, max=1.0, step=0.01),
184-
io.Int.Input("top_k", default=40, min=1, max=100, step=1),
185-
io.Int.Input("max_output_tokens", default=8192, min=1, max=65536, step=1),
186-
io.Boolean.Input("include_images", default=False),
187-
io.Combo.Input("aspect_ratio", options=["None", "1:1", "9:16", "16:9", "3:4", "4:3", "3:2", "2:3", "5:4", "4:5", "21:9"], default="None"),
188-
io.Combo.Input("bypass_mode", options=["None", "system_instruction", "prompt", "both"], default="None"),
189-
io.Int.Input("thinking_budget", default=0, min=-1, max=24576, step=1, tooltip="0 disables thinking mode, -1 will activate it as default dynamic thinking and anything above 0 sets specific budget"),
190-
io.Image.Input("input_image", optional=True),
191-
io.Image.Input("input_image_2", optional=True),
192-
io.Boolean.Input("use_proxy", default=False),
193-
io.String.Input("proxy_host", default="127.0.0.1"),
194-
io.Int.Input("proxy_port", default=7890, min=1, max=65535),
195-
io.Boolean.Input("use_seed", default=True),
196-
io.Int.Input("seed", default=0, min=0, max=2147483647),
197-
io.Int.Input("timeout", default=30, min=15, max=300, step=15),
198-
io.Boolean.Input("include_thoughts", default=False),
199-
io.Combo.Input("thinking_level", options=["None", "low", "medium", "high"], default="None", tooltip="Does not work at the same time as 'thinking_budget'. if this is set, then thinking budget is ignored."),
200-
io.Combo.Input("media_resolution", options=["unspecified", "low", "medium", "high"], default="unspecified", tooltip="Set input media resolution for image, video and pdf. This changes tokens consumed."),
180+
cio.String.Input("prompt", multiline=True),
181+
cio.String.Input("system_instruction", default="You are a helpful AI assistant.", multiline=True),
182+
cio.Combo.Input("model", options=["learnlm-2.0-flash-experimental", "gemini-exp-1206", "gemini-2.0-flash", "gemini-2.0-flash-lite-001", "gemini-2.0-flash-exp", "gemini-2.0-flash-thinking-exp", "gemini-2.0-flash-thinking-exp-01-21", "gemini-2.0-flash-thinking-exp-1219", "gemini-2.5-pro", "gemini-2.5-pro-preview-05-06", "gemini-2.5-flash", "gemini-2.5-flash-preview-09-2025", "gemini-3-pro-preview", "gemini-2.5-flash-image-preview"], default="gemini-2.0-flash"),
183+
cio.Float.Input("temperature", default=1.0, min=0.0, max=1.0, step=0.01),
184+
cio.Float.Input("top_p", default=0.95, min=0.0, max=1.0, step=0.01),
185+
cio.Int.Input("top_k", default=40, min=1, max=100, step=1),
186+
cio.Int.Input("max_output_tokens", default=8192, min=1, max=65536, step=1),
187+
cio.Boolean.Input("include_images", default=False),
188+
cio.Combo.Input("aspect_ratio", options=["None", "1:1", "9:16", "16:9", "3:4", "4:3", "3:2", "2:3", "5:4", "4:5", "21:9"], default="None"),
189+
cio.Combo.Input("bypass_mode", options=["None", "system_instruction", "prompt", "both"], default="None"),
190+
cio.Int.Input("thinking_budget", default=0, min=-1, max=24576, step=1, tooltip="0 disables thinking mode, -1 will activate it as default dynamic thinking and anything above 0 sets specific budget"),
191+
cio.Image.Input("input_image", optional=True),
192+
cio.Image.Input("input_image_2", optional=True),
193+
cio.Boolean.Input("use_proxy", default=False),
194+
cio.String.Input("proxy_host", default="127.0.0.1"),
195+
cio.Int.Input("proxy_port", default=7890, min=1, max=65535),
196+
cio.Boolean.Input("use_seed", default=True),
197+
cio.Int.Input("seed", default=0, min=0, max=2147483647),
198+
cio.Int.Input("timeout", default=30, min=15, max=300, step=15),
199+
cio.Boolean.Input("include_thoughts", default=False),
200+
cio.Combo.Input("thinking_level", options=["None", "low", "medium", "high"], default="None", tooltip="Does not work at the same time as 'thinking_budget'. if this is set, then thinking budget is ignored."),
201+
cio.Combo.Input("media_resolution", options=["unspecified", "low", "medium", "high"], default="unspecified", tooltip="Set input media resolution for image, video and pdf. This changes tokens consumed."),
201202
],
202203
outputs=[
203-
io.String.Output("text"),
204-
io.Image.Output("image"),
205-
io.Int.Output("final_actual_seed")
204+
cio.String.Output("text"),
205+
cio.Image.Output("image"),
206+
cio.Int.Output("final_actual_seed")
206207
]
207208
)
208209

@@ -448,7 +449,7 @@ def _build_generate_content_config(cls, model, temperature, top_p, top_k, max_ou
448449
def execute(cls, config, prompt, system_instruction, model, temperature, top_p, top_k, max_output_tokens,
449450
include_images, aspect_ratio, bypass_mode, thinking_budget, input_image=None, input_image_2=None,
450451
use_proxy=False, proxy_host="127.0.0.1", proxy_port=7890, use_seed=False, seed=0, timeout=30,
451-
include_thoughts=False, thinking_level=None, media_resolution=None) -> io.NodeOutput:
452+
include_thoughts=False, thinking_level=None, media_resolution=None) -> cio.NodeOutput:
452453

453454
fingerprint, cached = cls._compute_fingerprint_and_check_cache(
454455
config, prompt, system_instruction, model, temperature, top_p, top_k, max_output_tokens,
@@ -461,7 +462,7 @@ def execute(cls, config, prompt, system_instruction, model, temperature, top_p,
461462
if cached is not None:
462463
cached_text, cached_image, cached_seed = cached
463464
print(f"[INFO] Returning cached result for fingerprint {fingerprint}")
464-
return io.NodeOutput(cached_text, cached_image, cached_seed)
465+
return cio.NodeOutput(cached_text, cached_image, cached_seed)
465466

466467
# --- keep most of the original implementation but converted to classmethod usage ---
467468
original_http_proxy = os.environ.get('HTTP_PROXY')
@@ -538,7 +539,7 @@ def add_headers(self, request, **kwargs):
538539
client = genai.Client(api_key=config.get("api_key"), http_options=types.HttpOptions(api_version=config.get("api_version")), **client_options) # type: ignore
539540
except Exception as e:
540541
print(f"[ERROR] Gemini client initialization failed: {str(e)}")
541-
return io.NodeOutput(f"Gemini client initialization failed: {str(e)}", cls.generate_empty_image(), actual_seed if actual_seed is not None else 0)
542+
return cio.NodeOutput(f"Gemini client initialization failed: {str(e)}", cls.generate_empty_image(), actual_seed if actual_seed is not None else 0)
542543

543544
# Network test (best-effort)
544545
try:
@@ -574,7 +575,7 @@ def add_headers(self, request, **kwargs):
574575
img_array = img[0].cpu().numpy()
575576
img_array = (img_array * 255).astype(np.uint8)
576577
pil_img = Image.fromarray(img_array)
577-
img_byte_arr = stdlib_io.BytesIO()
578+
img_byte_arr = io.BytesIO()
578579
pil_img.save(img_byte_arr, format='PNG')
579580
img_bytes = img_byte_arr.getvalue()
580581
if model in cls.MEDIA_RES_MODELS and media_resolution is not None and media_resolution != "unspecified":
@@ -585,7 +586,7 @@ def add_headers(self, request, **kwargs):
585586
contents = img_parts + [{"text": padded_prompt}]
586587
except Exception as e:
587588
print(f"[ERROR] Error processing input image: {str(e)}")
588-
return io.NodeOutput(f"Error processing input image: {str(e)}", cls.generate_empty_image(), actual_seed if actual_seed is not None else 0)
589+
return cio.NodeOutput(f"Error processing input image: {str(e)}", cls.generate_empty_image(), actual_seed if actual_seed is not None else 0)
589590
else:
590591
contents = padded_prompt
591592

@@ -714,7 +715,7 @@ def api_call():
714715
pass
715716

716717
try:
717-
return io.NodeOutput(text_output, image_tensor, final_actual_seed)
718+
return cio.NodeOutput(text_output, image_tensor, final_actual_seed)
718719
finally:
719720
if original_http_proxy:
720721
os.environ['HTTP_PROXY'] = original_http_proxy

0 commit comments

Comments (0)