Skip to content

Commit 2f65196

Browse files
committed
add the ReplicateRequstNode.
1 parent dcfae2f commit 2f65196

File tree

4 files changed

+157
-2
lines changed

4 files changed

+157
-2
lines changed

README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,3 +172,6 @@ This node is designed to return an empty conditioning, the size is zero. It can
172172
173173
## CropByMaskToSpecificSize
174174
This node is designed to crop the image by the mask to a specific size.
175+
176+
## ReplicateRequstNode
177+
his node is designed to generate images using Replicate's model API. It supports various aspect ratios, LoRA weights, and provides flexible error handling options.

py/node_replicate.py

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
import os
2+
import sys
3+
sys.path.append(".")
4+
import replicate
5+
import folder_paths
6+
import logging
7+
import yaml
8+
import numpy as np
9+
import requests
10+
from PIL import Image
11+
import io
12+
13+
logger = logging.getLogger(__name__)
14+
15+
config_dir = os.path.join(folder_paths.base_path, "config")
16+
if not os.path.exists(config_dir):
17+
os.makedirs(config_dir)
18+
19+
20+
def get_config():
21+
try:
22+
config_path = os.path.join(config_dir, 'replicate_config.yml')
23+
with open(config_path, 'r') as f:
24+
config = yaml.load(f, Loader=yaml.FullLoader)
25+
return config
26+
except:
27+
return {}
28+
29+
def save_config(config):
30+
config_path = os.path.join(config_dir, 'replicate_config.yml')
31+
with open(config_path, 'w') as f:
32+
yaml.dump(config, f, indent=4)
33+
34+
35+
class ReplicateRequstNode:
36+
def __init__(self, api_key=None):
37+
config = get_config()
38+
self.api_key = api_key or config.get("REPLICATE_API_TOKEN")
39+
if self.api_key is not None:
40+
self.configure_replicate()
41+
42+
def configure_replicate(self):
43+
if self.api_key:
44+
os.environ["REPLICATE_API_TOKEN"] = self.api_key
45+
46+
@classmethod
47+
def INPUT_TYPES(cls):
48+
return {
49+
"required": {
50+
"prompt": ("STRING", {"default": "style of 80s cyberpunk, a portrait photo", "multiline": True}),
51+
"seed": ("INT", {"default": 42, "min": 0, "max": 2147483647}),
52+
"aspect_ratio": (["1:1", "16:9", "21:9", "3:2", "4:3", "5:4", "9:16", "9:21", "2:3", "3:4", "4:5"], {"default": "1:1"}),
53+
"steps": ("INT", {"default": 28, "min": 1, "max": 100}),
54+
"guidance": ("FLOAT", {"default": 3.5, "min": 0.1, "max": 100.0, "step": 0.1}),
55+
"go_fast": ("BOOLEAN", {"default": True}),
56+
},
57+
"optional": {
58+
"api_key": ("STRING", {"default": ""}),
59+
"lora_path": ("STRING", {"default": ""}),
60+
"lora_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 2.0, "step": 0.01}),
61+
"extra_lora": ("STRING", {"default": ""}),
62+
"extra_lora_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 2.0, "step": 0.1}),
63+
"model": ("STRING", {"default": "black-forest-labs/flux-dev-lora"}),
64+
"num_outputs": ("INT", {"default": 1, "min": 1, "max": 10}),
65+
}
66+
}
67+
68+
RETURN_TYPES = ("IMAGE", "INT", "INT", "STRING")
69+
RETURN_NAMES = ("image", "width", "height", "url")
70+
FUNCTION = "generate_image"
71+
CATEGORY = "utils/image"
72+
73+
def generate_image(self, prompt, seed, aspect_ratio, steps, guidance, go_fast, lora_path="", lora_scale=1.0,
74+
api_key="", extra_lora="", extra_lora_scale=1.0, model="black-forest-labs/flux-dev-lora",
75+
num_outputs=1):
76+
# 更新API key
77+
if api_key.strip():
78+
self.api_key = api_key
79+
save_config({"REPLICATE_API_TOKEN": self.api_key})
80+
self.configure_replicate()
81+
82+
if not self.api_key:
83+
raise ValueError("API key not found in replicate_config.yml or node input")
84+
85+
try:
86+
# 准备输入参数
87+
input_params = {
88+
"prompt": prompt,
89+
"lora_weights": lora_path,
90+
"seed": seed,
91+
"aspect_ratio": aspect_ratio,
92+
"num_inference_steps": steps,
93+
"guidance": guidance,
94+
"go_fast": go_fast,
95+
"lora_scale": lora_scale,
96+
"output_format": "png",
97+
"num_outputs": num_outputs
98+
}
99+
100+
# 添加额外的LoRA参数
101+
if extra_lora.strip():
102+
input_params["extra_lora"] = extra_lora
103+
input_params["extra_lora_scale"] = extra_lora_scale
104+
105+
logger.debug(f"调用Replicate API,参数: {input_params}")
106+
107+
# 调用Replicate API
108+
output = replicate.run(model, input=input_params)
109+
110+
if not output or len(output) == 0:
111+
raise Exception("Replicate API返回空结果")
112+
113+
images = []
114+
urls = []
115+
for image_url in output:
116+
logger.debug(f"生成的图片URL: {image_url}")
117+
urls.append(str(image_url))
118+
# 下载图片
119+
response = requests.get(image_url)
120+
response.raise_for_status()
121+
122+
# 转换为PIL图像
123+
image = Image.open(io.BytesIO(response.content))
124+
width, height = image.size
125+
126+
image_array = np.array(image)
127+
if len(image_array.shape) == 3 and image_array.shape[2] == 4: # RGBA
128+
image_array = image_array[:, :, :3] # 转换为RGB
129+
images.append(image_array)
130+
131+
from .utils import np2tensor
132+
image_tensor = np2tensor(images)
133+
urls_str = ",".join(urls)
134+
135+
return (image_tensor, width, height, urls_str)
136+
137+
except Exception as e:
138+
logger.exception(f"Replicate API调用失败: {str(e)}")
139+
raise e
140+
141+
142+
NODE_CLASS_MAPPINGS = {
143+
"ReplicateRequstNode": ReplicateRequstNode,
144+
}
145+
146+
NODE_DISPLAY_NAME_MAPPINGS = {
147+
"ReplicateRequstNode": "Replicate Request",
148+
}
149+

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[project]
22
name = "comfyui-utils-nodes"
33
description = "Nodes:LoadImageWithSwitch, ImageBatchOneOrMore, ModifyTextGender, GenderControlOutput, ImageCompositeMaskedWithSwitch, ImageCompositeMaskedOneByOne, ColorCorrectOfUtils, SplitMask, MaskFastGrow, CheckpointLoaderSimpleWithSwitch, ImageResizeTo8x, MatchImageRatioToPreset, MaskFromFaceModel, MaskCoverFourCorners, DetectorForNSFW, DeepfaceAnalyzeFaceAttributes, VolcanoOutpainting, VolcanoImageEdit, etc."
4-
version = "1.3.4"
4+
version = "1.3.5"
55
license = { file = "LICENSE" }
66
dependencies = []
77

requirements.txt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,7 @@ tf-keras==2.17.0
1313
google-generativeai>0.4.1
1414

1515
# volcano outpainting node
16-
volcengine
16+
volcengine
17+
18+
# replicate node
19+
replicate>=0.22.0

0 commit comments

Comments
 (0)