+import logging
 import os
-import shutil
+from contextlib import nullcontext

 import matplotlib
+
 matplotlib.use('Agg')
-import sys
 import yaml
 from argparse import ArgumentParser
 from tqdm import tqdm
 from modules.keypoint_detector import KPDetector
 from modules.dense_motion import DenseMotionNetwork
 from modules.avd_network import AVDNetwork
+from utils import VideoReader, VideoWriter

-if sys.version_info[0] < 3:
-    raise Exception("You must use Python 3 or higher. Recommended version is Python 3.9")
+logger = logging.getLogger("TPSMM")

-def relative_kp(kp_source, kp_driving, kp_driving_initial):

+def relative_kp(kp_source, kp_driving, kp_driving_initial):
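+    # Scale driving keypoint displacements by sqrt(source hull area / initial
+    # driving hull area) so size differences between faces do not distort motion.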
     source_area = ConvexHull(kp_source['fg_kp'][0].data.cpu().numpy()).volume
     driving_area = ConvexHull(kp_driving_initial['fg_kp'][0].data.cpu().numpy()).volume
     adapt_movement_scale = np.sqrt(source_area) / np.sqrt(driving_area)
@@ -35,12 +36,13 @@ def relative_kp(kp_source, kp_driving, kp_driving_initial):

     return kp_new

+
 def load_checkpoints(config_path, checkpoint_path, device):
     with open(config_path) as f:
         config = yaml.full_load(f)

     inpainting = InpaintingNetwork(**config['model_params']['generator_params'],
-                                        **config['model_params']['common_params'])
+                                   **config['model_params']['common_params'])
     kp_detector = KPDetector(**config['model_params']['common_params'])
     dense_motion_network = DenseMotionNetwork(**config['model_params']['common_params'],
                                               **config['model_params']['dense_motion_params'])
@@ -50,30 +52,36 @@ def load_checkpoints(config_path, checkpoint_path, device):
     dense_motion_network.to(device)
     inpainting.to(device)
     avd_network.to(device)
-
+
     checkpoint = torch.load(checkpoint_path, map_location=device)
-
+
     inpainting.load_state_dict(checkpoint['inpainting_network'])
     kp_detector.load_state_dict(checkpoint['kp_detector'])
     dense_motion_network.load_state_dict(checkpoint['dense_motion_network'])
     if 'avd_network' in checkpoint:
         avd_network.load_state_dict(checkpoint['avd_network'])
-
+
     inpainting.eval()
     kp_detector.eval()
     dense_motion_network.eval()
     avd_network.eval()
-
+
     return inpainting, kp_detector, dense_motion_network, avd_network


-def make_animation(source_image, driving_video_generator, inpainting_network, kp_detector, dense_motion_network, avd_network, device: torch.device, mode='relative'):
+def make_animation(source_image, driving_video_generator, inpainting_network, kp_detector, dense_motion_network,
+                   avd_network, device: torch.device, mode='relative', autocast_dtype=torch.float16, autocast=False):
     assert mode in ['standard', 'relative', 'avd']
     with torch.no_grad():
-        with torch.autocast(device_type=str(device), dtype=torch.float16):
+
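+        # Autocast is now opt-in: use a no-op nullcontext when disabled instead
+        # of unconditionally forcing float16 as before.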
+        if autocast:
+            autocast_context = torch.autocast(device_type=str(device), dtype=autocast_dtype)
+        else:
+            autocast_context = nullcontext()
+
+        with autocast_context:
             source = torch.tensor(source_image[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2)
             source = source.to(device)
-            #driving = torch.tensor(np.array(driving_video)[np.newaxis].astype(np.float32)).permute(0, 4, 1, 2, 3).to(device)
             kp_source = kp_detector(source)

             first_frame = True
@@ -88,18 +96,19 @@ def make_animation(source_image, driving_video_generator, inpainting_network, kp
             kp_driving = kp_detector(driving_frame)
             if mode == 'standard':
                 kp_norm = kp_driving
-            elif mode == 'relative':
+            elif mode == 'relative':
                 kp_norm = relative_kp(kp_source=kp_source, kp_driving=kp_driving,
-                                  kp_driving_initial=kp_driving_initial)
+                                      kp_driving_initial=kp_driving_initial)
             elif mode == 'avd':
                 kp_norm = avd_network(kp_source, kp_driving)
             dense_motion = dense_motion_network(source_image=source, kp_driving=kp_norm,
-                                            kp_source=kp_source, bg_param=None,
-                                            dropout_flag=False)
+                                                kp_source=kp_source, bg_param=None,
+                                                dropout_flag=False)
             out = inpainting_network(source, dense_motion)

             yield np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0]

+
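+# Picks the driving frame whose normalized landmarks are closest to the source's,
+# so that 'relative' mode starts from a well-aligned pose.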
 def find_best_frame(source, driving, cpu):
     import face_alignment

@@ -110,11 +119,11 @@ def normalize_kp(kp):
         kp[:, :2] = kp[:, :2] / area
         return kp

-    fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=True,
-                                    device='cpu' if cpu else 'cuda')
+    fa = face_alignment.FaceAlignment(face_alignment.LandmarksType.TWO_D, flip_input=True,
+                                      device='cpu' if cpu else 'cuda')
     kp_source = fa.get_landmarks(255 * source)[0]
     kp_source = normalize_kp(kp_source)
-    norm = float('inf')
+    norm = float('inf')
     frame_num = 0
     for i, image in tqdm(enumerate(driving)):
         try:
@@ -127,24 +136,30 @@ def normalize_kp(kp):
         except:
             pass
     return frame_num
+
+
 def read_and_resize_frames(video_path, img_shape):
-    reader = imageio.get_reader(video_path)
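+    # utils.VideoReader stands in for imageio.get_reader; per the --driving_video
+    # help text it presumably also handles image folders and frame patterns.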
+    reader = VideoReader(video_path)
     for frame in reader:
         resized_frame = resize(frame, img_shape)[..., :3]
         yield resized_frame
+
     reader.close()

+
 def read_and_resize_frames_forward(video_path, img_shape, start_frame):
-    reader = imageio.get_reader(video_path)
+    reader = VideoReader(video_path)
     for idx, frame in enumerate(reader):
         if idx < start_frame:
             continue
         resized_frame = resize(frame, img_shape)[..., :3]
         yield resized_frame
     reader.close()

+
 def read_and_resize_frames_backward(video_path, img_shape, end_frame):
-    reader = imageio.get_reader(video_path)
+    reader = VideoReader(video_path)
+
     frames = []
     for idx, frame in enumerate(reader):
         if idx > end_frame:
@@ -154,72 +169,102 @@ def read_and_resize_frames_backward(video_path, img_shape, end_frame):
     reader.close()
     return reversed(frames)

+
 if __name__ == "__main__":
     parser = ArgumentParser()
     parser.add_argument("--config", required=True, help="path to config")
     parser.add_argument("--checkpoint", default='checkpoints/vox.pth.tar', help="path to checkpoint to restore")

     parser.add_argument("--source_image", default='./assets/source.png', help="path to source image")
-    parser.add_argument("--driving_video", default='./assets/driving.mp4', help="path to driving video")
-    parser.add_argument("--result_video", default='./result.mp4', help="path to output")
-
+    parser.add_argument("--driving_video", default='./assets/driving.mp4', help="path to driving video or folder of images")
+    parser.add_argument("--result_video", default='./result.mp4', help="path to output; can be a file name or a folder")
+
     parser.add_argument("--img_shape", default="256,256", type=lambda x: list(map(int, x.split(','))),
                         help='Shape of image that the model was trained on.')
-
-    parser.add_argument("--mode", default='relative', choices=['standard', 'relative', 'avd'], help="Animate mode: ['standard', 'relative', 'avd'], when use the relative mode to animate a face, use '--find_best_frame' can get better quality result")
-
-    parser.add_argument("--find_best_frame", dest="find_best_frame", action="store_true",
+
+    parser.add_argument("--mode", default='relative', choices=['standard', 'relative', 'avd'],
+                        help="Animation mode; when using 'relative' mode on a face, '--find_best_frame' can improve quality")
+
+    parser.add_argument("--find_best_frame", dest="find_best_frame", action="store_true",
                         help="Generate from the frame that is best aligned with the source. (Only for faces, requires the face_alignment lib)")

     parser.add_argument("--cpu", dest="cpu", action="store_true", help="cpu mode.")
+    parser.add_argument("--autocast", dest="autocast", action="store_true", help="enable mixed-precision inference via torch.autocast.")
+

     opt = parser.parse_args()

     source_image = imageio.imread(opt.source_image)
-    reader = imageio.get_reader(opt.driving_video)
-    fps = reader.get_meta_data()['fps']
-    reader.close()
-
-    if opt.cpu:
+
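+    # The driving input may be a directory of frames, a printf-style frame
+    # pattern, or a video file; fps feeds the output writer and length is only
+    # an estimate for the tqdm progress totals.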
+    if os.path.isdir(opt.driving_video):
+        fps = 30
+        length = len(os.listdir(opt.driving_video))
+    elif "%" in opt.driving_video:
+        fps = 30
+        # count the frame files next to the pattern, not the characters in the path
+        length = len(os.listdir(os.path.dirname(opt.driving_video)))
+    else:
+        reader = imageio.get_reader(opt.driving_video, mode='I')
+
+        fps = reader.get_meta_data().get('fps', 30)
+        length = int(reader.get_meta_data().get('duration', 0) * fps)
+
+        reader.close()
+
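+    # CPU autocast is only reliable with bfloat16; float16 remains the usual
+    # choice on CUDA.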
+    if opt.cpu and opt.autocast:
+        autocast_dtype = torch.bfloat16
+    else:
+        autocast_dtype = torch.float16
+
+    if opt.cpu or torch.cuda.device_count() == 0:
         device = torch.device('cpu')
     else:
         device = torch.device('cuda')
-
+
     source_image = resize(source_image, opt.img_shape)[..., :3]
-    inpainting, kp_detector, dense_motion_network, avd_network = load_checkpoints(config_path=opt.config, checkpoint_path=opt.checkpoint, device=device)
+    inpainting, kp_detector, dense_motion_network, avd_network = load_checkpoints(config_path=opt.config,
+                                                                                  checkpoint_path=opt.checkpoint,
+                                                                                  device=device)
+

     def reversed_generator(generator):
         frames = list(generator)
         return reversed(frames)

+
     def append_frame_to_writer(frame, writer):
         writer.append_data(img_as_ubyte(frame))

-
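+    # A single writer is created up front and entered as a context manager in
+    # both branches below (assumes utils.VideoWriter supports the with-protocol).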
+    writer = VideoWriter(opt.result_video, mode='I', fps=fps)
     if opt.find_best_frame:
         driving_video_generator = read_and_resize_frames(opt.driving_video, opt.img_shape)
         i = find_best_frame(source_image, driving_video_generator, opt.cpu)
-        print("Best frame:", i)
         driving_forward = read_and_resize_frames_forward(opt.driving_video, opt.img_shape, i)
         driving_backward = read_and_resize_frames_backward(opt.driving_video, opt.img_shape, i)

-        with imageio.get_writer(opt.result_video, mode='I', fps=fps) as writer:
+        with writer:
             # Generate and append frames for the reversed backward animation
             backward_animation = make_animation(source_image, driving_backward, inpainting, kp_detector,
-                                                dense_motion_network, avd_network, device=device, mode=opt.mode)
+                                                dense_motion_network, avd_network, device=device, mode=opt.mode,
+                                                autocast_dtype=autocast_dtype, autocast=opt.autocast)
+
             for frame in reversed_generator(backward_animation):
                 append_frame_to_writer(frame, writer)

             # Generate and append frames for forward animation, skipping the first frame
-            for idx, frame in enumerate(
+            for idx, frame in tqdm(enumerate(
                 make_animation(source_image, driving_forward, inpainting, kp_detector, dense_motion_network,
-                               avd_network, device=device, mode=opt.mode)):
+                               avd_network, device=device, mode=opt.mode, autocast_dtype=autocast_dtype,
+                               autocast=opt.autocast
+            )), total=length):
                 if idx == 0:
                     continue
                 append_frame_to_writer(frame, writer)
     else:
-        with imageio.get_writer(opt.result_video, mode='I', fps=fps) as writer:
+        with writer:
             driving_video_generator = read_and_resize_frames(opt.driving_video, opt.img_shape)
-            for frame in make_animation(source_image, driving_video_generator, inpainting, kp_detector, dense_motion_network,
-                                        avd_network, device=device, mode=opt.mode):
+            for frame in tqdm(
+                make_animation(source_image, driving_video_generator, inpainting, kp_detector, dense_motion_network,
+                               avd_network, device=device, mode=opt.mode, autocast_dtype=autocast_dtype,
+                               autocast=opt.autocast
+            ), total=length):
                 append_frame_to_writer(frame, writer)