1111import os
1212import sys
1313sys .path .append ("." )
14-
14+ from comfy_api .latest ._input_impl .video_types import VideoFromFile
15+ from comfy .comfy_types import IO , FileLocator , ComfyNodeABC
16+ from .utils import tensor2pil , np2tensor
1517logger = logging .getLogger (__name__ )
1618
1719config_dir = os .path .join (folder_paths .base_path , "config" )
@@ -169,7 +171,6 @@ def generate_image(self, prompt, seed, aspect_ratio, steps, guidance, go_fast, l
169171 # 处理输入图像
170172 if image is not None and len (image ) > 0 :
171173 # 将tensor转换为PIL图像,然后保存为临时文件
172- from .utils import tensor2pil
173174 pil_image = tensor2pil (image [0 ]) # 取第一张图片
174175
175176 # 创建临时文件
@@ -211,7 +212,6 @@ def generate_image(self, prompt, seed, aspect_ratio, steps, guidance, go_fast, l
211212 image_array = image_array [:, :, :3 ]
212213 images .append (image_array )
213214
214- from .utils import np2tensor
215215 image_tensor = np2tensor (images )
216216 urls_str = "," .join (urls )
217217
@@ -222,10 +222,141 @@ def generate_image(self, prompt, seed, aspect_ratio, steps, guidance, go_fast, l
222222 raise e
223223
224224
225+ class ReplicateVideoRequestNode :
226+ def __init__ (self , api_key = None ):
227+ from replicate .client import Client
228+ config = get_config ()
229+ self .api_key = api_key or config .get ("REPLICATE_API_TOKEN" )
230+ if self .api_key is not None :
231+ self .configure_replicate ()
232+ self .client = Client (timeout = 300 )
233+
234+ def configure_replicate (self ):
235+ if self .api_key :
236+ os .environ ["REPLICATE_API_TOKEN" ] = self .api_key
237+
238+ @classmethod
239+ def INPUT_TYPES (cls ):
240+ return {
241+ "required" : {
242+ "prompt" : ("STRING" , {"default" : "" , "multiline" : True }),
243+ "model" : ("STRING" , {"default" : "wan-video/wan-2.2-i2v-fast" }),
244+ "num_frames" : ("INT" , {"default" : 81 , "min" : 81 , "max" : 121 }),
245+ "resolution" : (["480p" , "720p" ], {"default" : "720p" }),
246+ "frames_per_second" : ("INT" , {"default" : 16 , "min" : 5 , "max" : 30 , "step" : 1 }),
247+ },
248+ "optional" : {
249+ "image" : ("IMAGE" ,),
250+ "go_fast" : ("BOOLEAN" , {"default" : True }),
251+ "sample_shift" : ("FLOAT" , {"default" : 12.0 , "min" : 1.0 , "max" : 20.0 , "step" : 0.1 }),
252+ "lora_weights_transformer" : ("STRING" , {"default" : "" }),
253+ "lora_scale_transformer" : ("FLOAT" , {"default" : 1.0 , "min" : 0.0 , "max" : 3.0 , "step" : 0.01 }),
254+ "lora_weights_transformer_2" : ("STRING" , {"default" : "" }),
255+ "lora_scale_transformer_2" : ("FLOAT" , {"default" : 1.0 , "min" : 0.0 , "max" : 3.0 , "step" : 0.01 }),
256+ "api_key" : ("STRING" , {"default" : "" }),
257+ "timeout" : ("INT" , {"default" : 300 , "min" : 1 , "max" : 3000 }),
258+ }
259+ }
260+
261+ RETURN_TYPES = (IO .VIDEO , "INT" , "INT" , "FLOAT" , "STRING" )
262+ RETURN_NAMES = ("video" , "width" , "height" , "fps" , "url" )
263+ FUNCTION = "generate_video"
264+ CATEGORY = "utils/video"
265+
266+ def generate_video (self , prompt , model , num_frames , resolution , frames_per_second , image = None ,
267+ go_fast = True , sample_shift = 12.0 , lora_weights_transformer = "" ,
268+ lora_scale_transformer = 1.0 , lora_weights_transformer_2 = "" ,
269+ lora_scale_transformer_2 = 1.0 , api_key = "" , timeout = 300 ):
270+
271+ if api_key .strip ():
272+ self .api_key = api_key
273+ save_config ({"REPLICATE_API_TOKEN" : self .api_key })
274+ self .configure_replicate ()
275+
276+ if not self .api_key :
277+ raise ValueError ("API key not found in replicate_config.yml or node input" )
278+
279+ try :
280+ input_params = {
281+ "prompt" : prompt ,
282+ "num_frames" : num_frames ,
283+ "resolution" : resolution ,
284+ "frames_per_second" : frames_per_second ,
285+ "go_fast" : go_fast ,
286+ "sample_shift" : sample_shift ,
287+ }
288+
289+ if image is not None and len (image ) > 0 :
290+ pil_image = tensor2pil (image [0 ])
291+
292+ with tempfile .NamedTemporaryFile (suffix = '.png' , delete = False ) as temp_file :
293+ pil_image .save (temp_file .name , format = 'PNG' )
294+ temp_file_path = temp_file .name
295+
296+ input_params ["image" ] = open (temp_file_path , "rb" )
297+
298+ if lora_weights_transformer .strip ():
299+ input_params ["lora_weights_transformer" ] = lora_weights_transformer
300+ input_params ["lora_scale_transformer" ] = lora_scale_transformer
301+
302+ if lora_weights_transformer_2 .strip ():
303+ input_params ["lora_weights_transformer_2" ] = lora_weights_transformer_2
304+ input_params ["lora_scale_transformer_2" ] = lora_scale_transformer_2
305+
306+ logger .debug (f"调用Replicate API生成视频,参数: { input_params } " )
307+
308+ runner = ComfyUIReplicateRun (timeout_seconds = timeout , check_interval = 1.0 )
309+ output = runner .run_with_interrupt_check (self .client , model , input = input_params )
310+
311+ if image is not None and len (image ) > 0 :
312+ try :
313+ input_params ["image" ].close ()
314+ os .unlink (temp_file_path )
315+ except :
316+ pass
317+
318+ if not isinstance (output , list ):
319+ output = [output ]
320+
321+ video_url = output [0 ] if output else None
322+ if not video_url :
323+ raise Exception ("No video URL returned from API" )
324+
325+ logger .debug (f"生成的视频URL: { video_url } " )
326+
327+ videos_dir = os .path .join (folder_paths .get_output_directory (), "videos_utils_nodes" )
328+ if not os .path .exists (videos_dir ):
329+ os .makedirs (videos_dir )
330+
331+ video_filename = f"replicate_video_{ int (time .time ())} .mp4"
332+ video_path = os .path .join (videos_dir , video_filename )
333+
334+ response = requests .get (video_url )
335+ response .raise_for_status ()
336+
337+ with open (video_path , 'wb' ) as f :
338+ f .write (response .content )
339+
340+ logger .info (f"视频已保存到: { video_path } " )
341+
342+ video_input = VideoFromFile (video_path )
343+ width , height = video_input .get_dimensions ()
344+ fps = float (frames_per_second )
345+
346+ return (video_input , width , height , fps , video_url )
347+
348+ except Exception as e :
349+ logger .exception (f"Replicate视频生成失败: { str (e )} " )
350+ raise e
351+
352+
225353NODE_CLASS_MAPPINGS = {
226354 "ReplicateRequstNode" : ReplicateRequstNode ,
355+ "ReplicateVideoRequestNode" : ReplicateVideoRequestNode ,
227356}
228357
229358NODE_DISPLAY_NAME_MAPPINGS = {
230- "ReplicateRequstNode" : "Replicate Request" ,
359+ "ReplicateVideoRequestNode" : "Replicate Video Request" ,
360+ "ReplicateRequstNode" : "Replicate Image Request" ,
231361}
362+
0 commit comments