Skip to content

Commit 96e8ad8

Browse files
committed
add the ReplicateVideoRequestNode
1 parent 112af70 commit 96e8ad8

File tree

3 files changed

+140
-6
lines changed

3 files changed

+140
-6
lines changed

README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,3 +175,6 @@ This node is designed to crop the image by the mask to a specific size.
175175
176176
## ReplicateRequstNode
177177
This node is designed to generate images using Replicate's model API. It supports various aspect ratios, LoRA weights, and provides flexible error handling options.
178+
179+
## ReplicateVideoRequestNode
180+
This node is designed to generate videos using Replicate's model API, with Wan2.2 as the default video model.

py/node_replicate.py

Lines changed: 135 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@
1111
import os
1212
import sys
1313
sys.path.append(".")
14-
14+
from comfy_api.latest._input_impl.video_types import VideoFromFile
15+
from comfy.comfy_types import IO, FileLocator, ComfyNodeABC
16+
from .utils import tensor2pil, np2tensor
1517
logger = logging.getLogger(__name__)
1618

1719
config_dir = os.path.join(folder_paths.base_path, "config")
@@ -169,7 +171,6 @@ def generate_image(self, prompt, seed, aspect_ratio, steps, guidance, go_fast, l
169171
# 处理输入图像
170172
if image is not None and len(image) > 0:
171173
# 将tensor转换为PIL图像,然后保存为临时文件
172-
from .utils import tensor2pil
173174
pil_image = tensor2pil(image[0]) # 取第一张图片
174175

175176
# 创建临时文件
@@ -211,7 +212,6 @@ def generate_image(self, prompt, seed, aspect_ratio, steps, guidance, go_fast, l
211212
image_array = image_array[:, :, :3]
212213
images.append(image_array)
213214

214-
from .utils import np2tensor
215215
image_tensor = np2tensor(images)
216216
urls_str = ",".join(urls)
217217

@@ -222,10 +222,141 @@ def generate_image(self, prompt, seed, aspect_ratio, steps, guidance, go_fast, l
222222
raise e
223223

224224

225+
class ReplicateVideoRequestNode:
226+
def __init__(self, api_key=None):
227+
from replicate.client import Client
228+
config = get_config()
229+
self.api_key = api_key or config.get("REPLICATE_API_TOKEN")
230+
if self.api_key is not None:
231+
self.configure_replicate()
232+
self.client = Client(timeout=300)
233+
234+
def configure_replicate(self):
235+
if self.api_key:
236+
os.environ["REPLICATE_API_TOKEN"] = self.api_key
237+
238+
@classmethod
239+
def INPUT_TYPES(cls):
240+
return {
241+
"required": {
242+
"prompt": ("STRING", {"default": "", "multiline": True}),
243+
"model": ("STRING", {"default": "wan-video/wan-2.2-i2v-fast"}),
244+
"num_frames": ("INT", {"default": 81, "min": 81, "max": 121}),
245+
"resolution": (["480p", "720p"], {"default": "720p"}),
246+
"frames_per_second": ("INT", {"default": 16, "min": 5, "max": 30, "step": 1}),
247+
},
248+
"optional": {
249+
"image": ("IMAGE",),
250+
"go_fast": ("BOOLEAN", {"default": True}),
251+
"sample_shift": ("FLOAT", {"default": 12.0, "min": 1.0, "max": 20.0, "step": 0.1}),
252+
"lora_weights_transformer": ("STRING", {"default": ""}),
253+
"lora_scale_transformer": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 3.0, "step": 0.01}),
254+
"lora_weights_transformer_2": ("STRING", {"default": ""}),
255+
"lora_scale_transformer_2": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 3.0, "step": 0.01}),
256+
"api_key": ("STRING", {"default": ""}),
257+
"timeout": ("INT", {"default": 300, "min": 1, "max": 3000}),
258+
}
259+
}
260+
261+
RETURN_TYPES = (IO.VIDEO, "INT", "INT", "FLOAT", "STRING")
262+
RETURN_NAMES = ("video", "width", "height", "fps", "url")
263+
FUNCTION = "generate_video"
264+
CATEGORY = "utils/video"
265+
266+
def generate_video(self, prompt, model, num_frames, resolution, frames_per_second, image=None,
267+
go_fast=True, sample_shift=12.0, lora_weights_transformer="",
268+
lora_scale_transformer=1.0, lora_weights_transformer_2="",
269+
lora_scale_transformer_2=1.0, api_key="", timeout=300):
270+
271+
if api_key.strip():
272+
self.api_key = api_key
273+
save_config({"REPLICATE_API_TOKEN": self.api_key})
274+
self.configure_replicate()
275+
276+
if not self.api_key:
277+
raise ValueError("API key not found in replicate_config.yml or node input")
278+
279+
try:
280+
input_params = {
281+
"prompt": prompt,
282+
"num_frames": num_frames,
283+
"resolution": resolution,
284+
"frames_per_second": frames_per_second,
285+
"go_fast": go_fast,
286+
"sample_shift": sample_shift,
287+
}
288+
289+
if image is not None and len(image) > 0:
290+
pil_image = tensor2pil(image[0])
291+
292+
with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as temp_file:
293+
pil_image.save(temp_file.name, format='PNG')
294+
temp_file_path = temp_file.name
295+
296+
input_params["image"] = open(temp_file_path, "rb")
297+
298+
if lora_weights_transformer.strip():
299+
input_params["lora_weights_transformer"] = lora_weights_transformer
300+
input_params["lora_scale_transformer"] = lora_scale_transformer
301+
302+
if lora_weights_transformer_2.strip():
303+
input_params["lora_weights_transformer_2"] = lora_weights_transformer_2
304+
input_params["lora_scale_transformer_2"] = lora_scale_transformer_2
305+
306+
logger.debug(f"调用Replicate API生成视频,参数: {input_params}")
307+
308+
runner = ComfyUIReplicateRun(timeout_seconds=timeout, check_interval=1.0)
309+
output = runner.run_with_interrupt_check(self.client, model, input=input_params)
310+
311+
if image is not None and len(image) > 0:
312+
try:
313+
input_params["image"].close()
314+
os.unlink(temp_file_path)
315+
except:
316+
pass
317+
318+
if not isinstance(output, list):
319+
output = [output]
320+
321+
video_url = output[0] if output else None
322+
if not video_url:
323+
raise Exception("No video URL returned from API")
324+
325+
logger.debug(f"生成的视频URL: {video_url}")
326+
327+
videos_dir = os.path.join(folder_paths.get_output_directory(), "videos_utils_nodes")
328+
if not os.path.exists(videos_dir):
329+
os.makedirs(videos_dir)
330+
331+
video_filename = f"replicate_video_{int(time.time())}.mp4"
332+
video_path = os.path.join(videos_dir, video_filename)
333+
334+
response = requests.get(video_url)
335+
response.raise_for_status()
336+
337+
with open(video_path, 'wb') as f:
338+
f.write(response.content)
339+
340+
logger.info(f"视频已保存到: {video_path}")
341+
342+
video_input = VideoFromFile(video_path)
343+
width, height = video_input.get_dimensions()
344+
fps = float(frames_per_second)
345+
346+
return (video_input, width, height, fps, video_url)
347+
348+
except Exception as e:
349+
logger.exception(f"Replicate视频生成失败: {str(e)}")
350+
raise e
351+
352+
225353
NODE_CLASS_MAPPINGS = {
226354
"ReplicateRequstNode": ReplicateRequstNode,
355+
"ReplicateVideoRequestNode": ReplicateVideoRequestNode,
227356
}
228357

229358
NODE_DISPLAY_NAME_MAPPINGS = {
230-
"ReplicateRequstNode": "Replicate Request",
359+
"ReplicateVideoRequestNode": "Replicate Video Request",
360+
"ReplicateRequstNode": "Replicate Image Request",
231361
}
362+

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[project]
22
name = "comfyui-utils-nodes"
3-
description = "Nodes:LoadImageWithSwitch, ImageBatchOneOrMore, ModifyTextGender, GenderControlOutput, ImageCompositeMaskedWithSwitch, ImageCompositeMaskedOneByOne, ColorCorrectOfUtils, SplitMask, MaskFastGrow, CheckpointLoaderSimpleWithSwitch, ImageResizeTo8x, MatchImageRatioToPreset, MaskFromFaceModel, MaskCoverFourCorners, DetectorForNSFW, DeepfaceAnalyzeFaceAttributes, VolcanoOutpainting, VolcanoImageEdit, etc."
4-
version = "1.3.6"
3+
description = "Nodes:LoadImageWithSwitch, ImageBatchOneOrMore, GenderControlOutput, ImageCompositeMaskedWithSwitch, ImageCompositeMaskedOneByOne, ColorCorrectOfUtils, SplitMask, MaskFastGrow, CheckpointLoaderSimpleWithSwitch, ImageResizeTo8x, MatchImageRatioToPreset, MaskFromFaceModel, MaskCoverFourCorners, DetectorForNSFW, DeepfaceAnalyzeFaceAttributes, VolcanoOutpainting, VolcanoImageEdit, ReplicateRequstNode etc."
4+
version = "1.3.7"
55
license = { file = "LICENSE" }
66
dependencies = []
77

0 commit comments

Comments
 (0)