1717from ..utils .utils import video_has_audio
1818from ..utils .utils import resize_to_limit , prepare_paste_back , get_rotation_matrix , calc_lip_close_ratio , \
1919 calc_eye_close_ratio , transform_keypoint , concat_feat
20- from ..utils .crop import crop_image , parse_bbox_from_landmark , crop_image_by_bbox , paste_back
20+ from ..utils .crop import crop_image , parse_bbox_from_landmark , crop_image_by_bbox , paste_back , paste_back_pytorch
2121from src .utils import utils
2222import platform
23+ import torch
24+ from PIL import Image
2325
2426if platform .system ().lower () == 'windows' :
2527 FFMPEG = "third_party/ffmpeg-7.0.1-full_build/bin/ffmpeg.exe"
@@ -125,7 +127,12 @@ def run_local(self, driving_video_path, source_path, **kwargs):
125127 raise gr .Error (f"Error in processing source:{ source_path } 💥!" , duration = 5 )
126128
127129 vcap = cv2 .VideoCapture (driving_video_path )
128- fps = int (vcap .get (cv2 .CAP_PROP_FPS ))
130+ if self .is_source_video :
131+ duration , fps = utils .get_video_info (self .source_path )
132+ fps = int (fps )
133+ else :
134+ fps = int (vcap .get (cv2 .CAP_PROP_FPS ))
135+
129136 dframe = int (vcap .get (cv2 .CAP_PROP_FRAME_COUNT ))
130137 if self .is_source_video :
131138 max_frame = min (dframe , len (self .src_imgs ))
@@ -168,19 +175,38 @@ def run_local(self, driving_video_path, source_path, **kwargs):
168175
169176 if video_has_audio (driving_video_path ):
170177 vsave_crop_path_new = os .path .splitext (vsave_crop_path )[0 ] + "-audio.mp4"
171- subprocess .call (
172- [FFMPEG , "-i" , vsave_crop_path , "-i" , driving_video_path ,
173- "-b:v" , "10M" , "-c:v" ,
174- "libx264" , "-map" , "0:v" , "-map" , "1:a" ,
175- "-c:a" , "aac" ,
176- "-pix_fmt" , "yuv420p" , vsave_crop_path_new , "-y" , "-shortest" ])
177178 vsave_org_path_new = os .path .splitext (vsave_org_path )[0 ] + "-audio.mp4"
178- subprocess .call (
179- [FFMPEG , "-i" , vsave_org_path , "-i" , driving_video_path ,
180- "-b:v" , "10M" , "-c:v" ,
181- "libx264" , "-map" , "0:v" , "-map" , "1:a" ,
182- "-c:a" , "aac" ,
183- "-pix_fmt" , "yuv420p" , vsave_org_path_new , "-y" , "-shortest" ])
179+ if self .is_source_video :
180+ duration , fps = utils .get_video_info (vsave_crop_path )
181+ subprocess .call (
182+ [FFMPEG , "-i" , vsave_crop_path , "-i" , driving_video_path ,
183+ "-b:v" , "10M" , "-c:v" , "libx264" , "-map" , "0:v" , "-map" , "1:a" ,
184+ "-c:a" , "aac" , "-pix_fmt" , "yuv420p" ,
185+ "-shortest" , # 以最短的流为基准
186+ "-t" , str (duration ), # 设置时长
187+ "-r" , str (fps ), # 设置帧率
188+ vsave_crop_path_new , "-y" ])
189+ subprocess .call (
190+ [FFMPEG , "-i" , vsave_org_path , "-i" , driving_video_path ,
191+ "-b:v" , "10M" , "-c:v" , "libx264" , "-map" , "0:v" , "-map" , "1:a" ,
192+ "-c:a" , "aac" , "-pix_fmt" , "yuv420p" ,
193+ "-shortest" , # 以最短的流为基准
194+ "-t" , str (duration ), # 设置时长
195+ "-r" , str (fps ), # 设置帧率
196+ vsave_org_path_new , "-y" ])
197+ else :
198+ subprocess .call (
199+ [FFMPEG , "-i" , vsave_crop_path , "-i" , driving_video_path ,
200+ "-b:v" , "10M" , "-c:v" ,
201+ "libx264" , "-map" , "0:v" , "-map" , "1:a" ,
202+ "-c:a" , "aac" ,
203+ "-pix_fmt" , "yuv420p" , vsave_crop_path_new , "-y" , "-shortest" ])
204+ subprocess .call (
205+ [FFMPEG , "-i" , vsave_org_path , "-i" , driving_video_path ,
206+ "-b:v" , "10M" , "-c:v" ,
207+ "libx264" , "-map" , "0:v" , "-map" , "1:a" ,
208+ "-c:a" , "aac" ,
209+ "-pix_fmt" , "yuv420p" , vsave_org_path_new , "-y" , "-shortest" ])
184210
185211 return vsave_org_path_new , vsave_crop_path_new , total_time
186212 else :
@@ -207,9 +233,10 @@ def execute_image(self, input_eye_ratio: float, input_lip_ratio: float, input_im
207233 x_d_new = x_s_user + eyes_delta .reshape (- 1 , num_kp , 3 ) + lip_delta .reshape (- 1 , num_kp , 3 )
208234 # D(W(f_s; x_s, x′_d))
209235 out = self .model_dict ["warping_spade" ].predict (f_s_user , x_s_user , x_d_new )
210- out_to_ori_blend = paste_back (out , crop_M_c2o , img_rgb , mask_ori )
236+ img_rgb = torch .from_numpy (img_rgb ).to (self .device )
237+ out_to_ori_blend = paste_back_pytorch (out , crop_M_c2o , img_rgb , mask_ori )
211238 gr .Info ("Run successfully!" , duration = 2 )
212- return out , out_to_ori_blend
239+ return out . to ( dtype = torch . uint8 ). cpu (). numpy () , out_to_ori_blend . to ( dtype = torch . uint8 ). cpu (). numpy ()
213240
214241 def prepare_retargeting (self , input_image , flag_do_crop = True ):
215242 """ for single image retargeting
@@ -221,16 +248,18 @@ def prepare_retargeting(self, input_image, flag_do_crop=True):
221248 self .cfg .infer_params .source_division )
222249 img_rgb = cv2 .cvtColor (img_bgr , cv2 .COLOR_BGR2RGB )
223250
224- src_faces = self .model_dict ["face_analysis" ].predict (img_bgr )
251+ if self .is_animal :
252+ raise gr .Error ("Animal Model Not Supported in Face Retarget 💥!" , duration = 5 )
253+ else :
254+ src_faces = self .model_dict ["face_analysis" ].predict (img_bgr )
225255
226256 if len (src_faces ) == 0 :
227257 raise gr .Error ("No face detect in image 💥!" , duration = 5 )
228258 src_faces = src_faces [:1 ]
229259 crop_infos = []
230260 for i in range (len (src_faces )):
231261 # NOTE: temporarily only pick the first face, to support multiple face in the future
232- src_face = src_faces [i ]
233- lmk = src_face .landmark # this is the 106 landmarks from insightface
262+ lmk = src_faces [i ]
234263 # crop the face
235264 ret_dct = crop_image (
236265 img_rgb , # ndarray
@@ -240,8 +269,10 @@ def prepare_retargeting(self, input_image, flag_do_crop=True):
240269 vx_ratio = self .cfg .crop_params .src_vx_ratio ,
241270 vy_ratio = self .cfg .crop_params .src_vy_ratio ,
242271 )
272+
243273 lmk = self .model_dict ["landmark" ].predict (img_rgb , lmk )
244274 ret_dct ["lmk_crop" ] = lmk
275+ ret_dct ["lmk_crop_256x256" ] = ret_dct ["lmk_crop" ] * 256 / self .cfg .crop_params .src_dsize
245276
246277 # update a 256x256 version for network input
247278 ret_dct ["img_crop_256x256" ] = cv2 .resize (
@@ -270,9 +301,10 @@ def prepare_retargeting(self, input_image, flag_do_crop=True):
270301 x_s_user = transform_keypoint (pitch , yaw , roll , t , exp , scale , kp )
271302 source_lmk_user = crop_info ['lmk_crop' ]
272303 crop_M_c2o = crop_info ['M_c2o' ]
273-
304+ crop_M_c2o = torch . from_numpy ( crop_M_c2o ). to ( self . device )
274305 mask_ori = prepare_paste_back (self .mask_crop , crop_info ['M_c2o' ],
275306 dsize = (img_rgb .shape [1 ], img_rgb .shape [0 ]))
307+ mask_ori = torch .from_numpy (mask_ori ).to (self .device ).float ()
276308 return f_s_user , x_s_user , source_lmk_user , crop_M_c2o , mask_ori , img_rgb
277309 else :
278310 # when press the clear button, go here
0 commit comments