Skip to content

Commit 8bc2dcb

Browse files
authored
Merge pull request #70 from warmshao/anim
修复v2v视频长度和retarget模块的问题
2 parents 981057e + 4bea43e commit 8bc2dcb

File tree

4 files changed

+73
-24
lines changed

4 files changed

+73
-24
lines changed

README.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,7 @@
66
**New features:**
77
* Achieved real-time running of LivePortrait on RTX 3090 GPU using TensorRT, reaching speeds of 30+ FPS. This is the speed for rendering a single frame, including pre- and post-processing, not just the model inference speed.
88
* Implemented conversion of LivePortrait model to Onnx model, achieving inference speed of about 70ms/frame (~12 FPS) using onnxruntime-gpu on RTX 3090, facilitating cross-platform deployment.
9-
* Seamless support for native gradio app, with several times faster speed and support for simultaneous inference on multiple faces. Some results can be seen here: [pr105](https://github.com/KwaiVGI/LivePortrait/pull/105)
10-
* Refactored code structure, no longer dependent on pytorch, all models use onnx or tensorrt for inference.
9+
* Seamless support for native gradio app, with several times faster speed and support for simultaneous inference on multiple faces and Animal Model.
1110

1211
**If you find this project useful, please give it a star ✨✨**
1312

README_CN.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,7 @@
66
**新增功能:**
77
* 通过TensorRT实现在RTX 3090显卡上**实时**运行LivePortrait,速度达到 30+ FPS. 这个速度是实测渲染出一帧的速度,而不仅仅是模型的推理时间。
88
* 实现将LivePortrait模型转为Onnx模型,使用onnxruntime-gpu在RTX 3090上的推理速度约为 70ms/帧(~12 FPS),方便跨平台的部署。
9-
* 无缝支持原生的gradio app, 速度快了好几倍,同时支持对多张人脸的同时推理,一些效果可以看:[pr105](https://github.com/KwaiVGI/LivePortrait/pull/105)
10-
* 对代码结构进行了重构,不再依赖pytorch,所有的模型用onnx或tensorrt推理。
9+
* 无缝支持原生的gradio app, 速度快了好几倍,支持多张人脸、Animal模型。
1110

1211
**如果你觉得这个项目有用,帮我点个star吧✨✨**
1312

src/pipelines/gradio_live_portrait_pipeline.py

Lines changed: 52 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,11 @@
1717
from ..utils.utils import video_has_audio
1818
from ..utils.utils import resize_to_limit, prepare_paste_back, get_rotation_matrix, calc_lip_close_ratio, \
1919
calc_eye_close_ratio, transform_keypoint, concat_feat
20-
from ..utils.crop import crop_image, parse_bbox_from_landmark, crop_image_by_bbox, paste_back
20+
from ..utils.crop import crop_image, parse_bbox_from_landmark, crop_image_by_bbox, paste_back, paste_back_pytorch
2121
from src.utils import utils
2222
import platform
23+
import torch
24+
from PIL import Image
2325

2426
if platform.system().lower() == 'windows':
2527
FFMPEG = "third_party/ffmpeg-7.0.1-full_build/bin/ffmpeg.exe"
@@ -125,7 +127,12 @@ def run_local(self, driving_video_path, source_path, **kwargs):
125127
raise gr.Error(f"Error in processing source:{source_path} 💥!", duration=5)
126128

127129
vcap = cv2.VideoCapture(driving_video_path)
128-
fps = int(vcap.get(cv2.CAP_PROP_FPS))
130+
if self.is_source_video:
131+
duration, fps = utils.get_video_info(self.source_path)
132+
fps = int(fps)
133+
else:
134+
fps = int(vcap.get(cv2.CAP_PROP_FPS))
135+
129136
dframe = int(vcap.get(cv2.CAP_PROP_FRAME_COUNT))
130137
if self.is_source_video:
131138
max_frame = min(dframe, len(self.src_imgs))
@@ -168,19 +175,38 @@ def run_local(self, driving_video_path, source_path, **kwargs):
168175

169176
if video_has_audio(driving_video_path):
170177
vsave_crop_path_new = os.path.splitext(vsave_crop_path)[0] + "-audio.mp4"
171-
subprocess.call(
172-
[FFMPEG, "-i", vsave_crop_path, "-i", driving_video_path,
173-
"-b:v", "10M", "-c:v",
174-
"libx264", "-map", "0:v", "-map", "1:a",
175-
"-c:a", "aac",
176-
"-pix_fmt", "yuv420p", vsave_crop_path_new, "-y", "-shortest"])
177178
vsave_org_path_new = os.path.splitext(vsave_org_path)[0] + "-audio.mp4"
178-
subprocess.call(
179-
[FFMPEG, "-i", vsave_org_path, "-i", driving_video_path,
180-
"-b:v", "10M", "-c:v",
181-
"libx264", "-map", "0:v", "-map", "1:a",
182-
"-c:a", "aac",
183-
"-pix_fmt", "yuv420p", vsave_org_path_new, "-y", "-shortest"])
179+
if self.is_source_video:
180+
duration, fps = utils.get_video_info(vsave_crop_path)
181+
subprocess.call(
182+
[FFMPEG, "-i", vsave_crop_path, "-i", driving_video_path,
183+
"-b:v", "10M", "-c:v", "libx264", "-map", "0:v", "-map", "1:a",
184+
"-c:a", "aac", "-pix_fmt", "yuv420p",
185+
"-shortest", # 以最短的流为基准
186+
"-t", str(duration), # 设置时长
187+
"-r", str(fps), # 设置帧率
188+
vsave_crop_path_new, "-y"])
189+
subprocess.call(
190+
[FFMPEG, "-i", vsave_org_path, "-i", driving_video_path,
191+
"-b:v", "10M", "-c:v", "libx264", "-map", "0:v", "-map", "1:a",
192+
"-c:a", "aac", "-pix_fmt", "yuv420p",
193+
"-shortest", # 以最短的流为基准
194+
"-t", str(duration), # 设置时长
195+
"-r", str(fps), # 设置帧率
196+
vsave_org_path_new, "-y"])
197+
else:
198+
subprocess.call(
199+
[FFMPEG, "-i", vsave_crop_path, "-i", driving_video_path,
200+
"-b:v", "10M", "-c:v",
201+
"libx264", "-map", "0:v", "-map", "1:a",
202+
"-c:a", "aac",
203+
"-pix_fmt", "yuv420p", vsave_crop_path_new, "-y", "-shortest"])
204+
subprocess.call(
205+
[FFMPEG, "-i", vsave_org_path, "-i", driving_video_path,
206+
"-b:v", "10M", "-c:v",
207+
"libx264", "-map", "0:v", "-map", "1:a",
208+
"-c:a", "aac",
209+
"-pix_fmt", "yuv420p", vsave_org_path_new, "-y", "-shortest"])
184210

185211
return vsave_org_path_new, vsave_crop_path_new, total_time
186212
else:
@@ -207,9 +233,10 @@ def execute_image(self, input_eye_ratio: float, input_lip_ratio: float, input_im
207233
x_d_new = x_s_user + eyes_delta.reshape(-1, num_kp, 3) + lip_delta.reshape(-1, num_kp, 3)
208234
# D(W(f_s; x_s, x′_d))
209235
out = self.model_dict["warping_spade"].predict(f_s_user, x_s_user, x_d_new)
210-
out_to_ori_blend = paste_back(out, crop_M_c2o, img_rgb, mask_ori)
236+
img_rgb = torch.from_numpy(img_rgb).to(self.device)
237+
out_to_ori_blend = paste_back_pytorch(out, crop_M_c2o, img_rgb, mask_ori)
211238
gr.Info("Run successfully!", duration=2)
212-
return out, out_to_ori_blend
239+
return out.to(dtype=torch.uint8).cpu().numpy(), out_to_ori_blend.to(dtype=torch.uint8).cpu().numpy()
213240

214241
def prepare_retargeting(self, input_image, flag_do_crop=True):
215242
""" for single image retargeting
@@ -221,16 +248,18 @@ def prepare_retargeting(self, input_image, flag_do_crop=True):
221248
self.cfg.infer_params.source_division)
222249
img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
223250

224-
src_faces = self.model_dict["face_analysis"].predict(img_bgr)
251+
if self.is_animal:
252+
raise gr.Error("Animal Model Not Supported in Face Retarget 💥!", duration=5)
253+
else:
254+
src_faces = self.model_dict["face_analysis"].predict(img_bgr)
225255

226256
if len(src_faces) == 0:
227257
raise gr.Error("No face detect in image 💥!", duration=5)
228258
src_faces = src_faces[:1]
229259
crop_infos = []
230260
for i in range(len(src_faces)):
231261
# NOTE: temporarily only pick the first face, to support multiple face in the future
232-
src_face = src_faces[i]
233-
lmk = src_face.landmark # this is the 106 landmarks from insightface
262+
lmk = src_faces[i]
234263
# crop the face
235264
ret_dct = crop_image(
236265
img_rgb, # ndarray
@@ -240,8 +269,10 @@ def prepare_retargeting(self, input_image, flag_do_crop=True):
240269
vx_ratio=self.cfg.crop_params.src_vx_ratio,
241270
vy_ratio=self.cfg.crop_params.src_vy_ratio,
242271
)
272+
243273
lmk = self.model_dict["landmark"].predict(img_rgb, lmk)
244274
ret_dct["lmk_crop"] = lmk
275+
ret_dct["lmk_crop_256x256"] = ret_dct["lmk_crop"] * 256 / self.cfg.crop_params.src_dsize
245276

246277
# update a 256x256 version for network input
247278
ret_dct["img_crop_256x256"] = cv2.resize(
@@ -270,9 +301,10 @@ def prepare_retargeting(self, input_image, flag_do_crop=True):
270301
x_s_user = transform_keypoint(pitch, yaw, roll, t, exp, scale, kp)
271302
source_lmk_user = crop_info['lmk_crop']
272303
crop_M_c2o = crop_info['M_c2o']
273-
304+
crop_M_c2o = torch.from_numpy(crop_M_c2o).to(self.device)
274305
mask_ori = prepare_paste_back(self.mask_crop, crop_info['M_c2o'],
275306
dsize=(img_rgb.shape[1], img_rgb.shape[0]))
307+
mask_ori = torch.from_numpy(mask_ori).to(self.device).float()
276308
return f_s_user, x_s_user, source_lmk_user, crop_M_c2o, mask_ori, img_rgb
277309
else:
278310
# when press the clear button, go here

src/utils/utils.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,25 @@ def video_has_audio(video_file):
1616
return False
1717

1818

19+
def get_video_info(video_path):
    """Return (duration, fps) for the first video stream of *video_path*.

    Uses ffmpeg.probe (ffprobe) to inspect the container metadata.

    Args:
        video_path: Path to a video file readable by ffprobe.

    Returns:
        Tuple ``(duration, fps)`` where ``duration`` is the length in
        seconds (float) and ``fps`` is the frame rate (float).

    Raises:
        ValueError: If the file has no video stream, no duration
            information, or an unusable frame-rate fraction.
    """
    probe = ffmpeg.probe(video_path)
    video_streams = [stream for stream in probe['streams'] if stream['codec_type'] == 'video']

    if not video_streams:
        raise ValueError("No video stream found")

    stream = video_streams[0]

    # Container-level duration; some containers omit it, so fall back to
    # the stream-level duration before giving up.
    duration_str = probe['format'].get('duration') or stream.get('duration')
    if duration_str is None:
        raise ValueError("No duration information found")
    duration = float(duration_str)

    # r_frame_rate is a fraction string such as "30000/1001".
    # NOTE(review): r_frame_rate can over-report for variable-frame-rate
    # input; avg_frame_rate is only used as a fallback when it is absent.
    fps_string = stream.get('r_frame_rate') or stream.get('avg_frame_rate', '0/0')
    numerator, denominator = map(int, fps_string.split('/'))
    if denominator == 0:
        # ffprobe reports "0/0" for streams with no known frame rate;
        # guard against ZeroDivisionError and fail with a clear message.
        raise ValueError(f"Invalid frame rate: {fps_string}")
    fps = numerator / denominator

    return duration, fps
36+
37+
1938
def resize_to_limit(img: np.ndarray, max_dim=1280, division=2):
2039
"""
2140
adjust the size of the image so that the maximum dimension does not exceed max_dim, and the width and the height of the image are multiples of n.

0 commit comments

Comments
 (0)