
Commit d3b6d0e

Merge branch 'main' of github.com:thegenerativegeneration/Thin-Plate-Spline-Motion-Model
2 parents: 4630846 + 25099de

File tree

6 files changed

+235
-54
lines changed


assets/driving_short.mp4

50.2 KB
Binary file not shown.

checkpoints/.gitkeep

Whitespace-only changes.

checkpoints/256/.gitkeep

Whitespace-only changes.

demo.py

Lines changed: 91 additions & 46 deletions
@@ -1,9 +1,10 @@
+import logging
 import os
-import shutil
+from contextlib import nullcontext
 
 import matplotlib
+
 matplotlib.use('Agg')
-import sys
 import yaml
 from argparse import ArgumentParser
 from tqdm import tqdm
@@ -17,12 +18,12 @@
 from modules.keypoint_detector import KPDetector
 from modules.dense_motion import DenseMotionNetwork
 from modules.avd_network import AVDNetwork
+from utils import VideoReader, VideoWriter
 
-if sys.version_info[0] < 3:
-    raise Exception("You must use Python 3 or higher. Recommended version is Python 3.9")
+logger = logging.getLogger("TPSMM")
 
-def relative_kp(kp_source, kp_driving, kp_driving_initial):
 
+def relative_kp(kp_source, kp_driving, kp_driving_initial):
     source_area = ConvexHull(kp_source['fg_kp'][0].data.cpu().numpy()).volume
     driving_area = ConvexHull(kp_driving_initial['fg_kp'][0].data.cpu().numpy()).volume
     adapt_movement_scale = np.sqrt(source_area) / np.sqrt(driving_area)
@@ -35,12 +36,13 @@ def relative_kp(kp_source, kp_driving, kp_driving_initial):
 
     return kp_new
 
+
 def load_checkpoints(config_path, checkpoint_path, device):
     with open(config_path) as f:
         config = yaml.full_load(f)
 
     inpainting = InpaintingNetwork(**config['model_params']['generator_params'],
-                                        **config['model_params']['common_params'])
+                                   **config['model_params']['common_params'])
     kp_detector = KPDetector(**config['model_params']['common_params'])
     dense_motion_network = DenseMotionNetwork(**config['model_params']['common_params'],
                                               **config['model_params']['dense_motion_params'])
@@ -50,30 +52,36 @@ def load_checkpoints(config_path, checkpoint_path, device):
     dense_motion_network.to(device)
     inpainting.to(device)
     avd_network.to(device)
-
+
     checkpoint = torch.load(checkpoint_path, map_location=device)
-
+
     inpainting.load_state_dict(checkpoint['inpainting_network'])
     kp_detector.load_state_dict(checkpoint['kp_detector'])
     dense_motion_network.load_state_dict(checkpoint['dense_motion_network'])
     if 'avd_network' in checkpoint:
         avd_network.load_state_dict(checkpoint['avd_network'])
-
+
     inpainting.eval()
     kp_detector.eval()
     dense_motion_network.eval()
     avd_network.eval()
-
+
     return inpainting, kp_detector, dense_motion_network, avd_network
 
 
-def make_animation(source_image, driving_video_generator, inpainting_network, kp_detector, dense_motion_network, avd_network, device:torch.device, mode = 'relative'):
+def make_animation(source_image, driving_video_generator, inpainting_network, kp_detector, dense_motion_network,
+                   avd_network, device: torch.device, mode='relative', autocast_dtype=torch.float16, autocast=False):
     assert mode in ['standard', 'relative', 'avd']
     with torch.no_grad():
-        with torch.autocast(device_type=str(device), dtype=torch.float16):
+
+        if autocast:
+            autocast_context = torch.autocast(device_type=str(device), dtype=autocast_dtype)
+        else:
+            autocast_context = nullcontext()
+
+        with autocast_context:
             source = torch.tensor(source_image[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2)
             source = source.to(device)
-            #driving = torch.tensor(np.array(driving_video)[np.newaxis].astype(np.float32)).permute(0, 4, 1, 2, 3).to(device)
             kp_source = kp_detector(source)
 
             first_frame = True
@@ -88,18 +96,19 @@ def make_animation(source_image, driving_video_generator, inpainting_network, kp
                 kp_driving = kp_detector(driving_frame)
                 if mode == 'standard':
                     kp_norm = kp_driving
-                elif mode=='relative':
+                elif mode == 'relative':
                     kp_norm = relative_kp(kp_source=kp_source, kp_driving=kp_driving,
-                                        kp_driving_initial=kp_driving_initial)
+                                          kp_driving_initial=kp_driving_initial)
                 elif mode == 'avd':
                     kp_norm = avd_network(kp_source, kp_driving)
                 dense_motion = dense_motion_network(source_image=source, kp_driving=kp_norm,
-                                                    kp_source=kp_source, bg_param = None,
-                                                    dropout_flag = False)
+                                                    kp_source=kp_source, bg_param=None,
+                                                    dropout_flag=False)
                 out = inpainting_network(source, dense_motion)
 
                 yield np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0]
 
+
 def find_best_frame(source, driving, cpu):
     import face_alignment
 
@@ -110,11 +119,11 @@ def normalize_kp(kp):
         kp[:, :2] = kp[:, :2] / area
         return kp
 
-    fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=True,
-                                      device= 'cpu' if cpu else 'cuda')
+    fa = face_alignment.FaceAlignment(face_alignment.LandmarksType.TWO_D, flip_input=True,
+                                      device='cpu' if cpu else 'cuda')
     kp_source = fa.get_landmarks(255 * source)[0]
     kp_source = normalize_kp(kp_source)
-    norm  = float('inf')
+    norm = float('inf')
    frame_num = 0
    for i, image in tqdm(enumerate(driving)):
        try:
@@ -127,24 +136,30 @@ def normalize_kp(kp):
        except:
            pass
    return frame_num
+
+
 def read_and_resize_frames(video_path, img_shape):
-    reader = imageio.get_reader(video_path)
+    reader = VideoReader(video_path)
    for frame in reader:
        resized_frame = resize(frame, img_shape)[..., :3]
        yield resized_frame
+
    reader.close()
 
+
 def read_and_resize_frames_forward(video_path, img_shape, start_frame):
-    reader = imageio.get_reader(video_path)
+    reader = VideoReader(video_path)
    for idx, frame in enumerate(reader):
        if idx < start_frame:
            continue
        resized_frame = resize(frame, img_shape)[..., :3]
        yield resized_frame
    reader.close()
 
+
 def read_and_resize_frames_backward(video_path, img_shape, end_frame):
-    reader = imageio.get_reader(video_path)
+    reader = VideoReader(video_path)
+
    frames = []
    for idx, frame in enumerate(reader):
        if idx > end_frame:
@@ -154,72 +169,102 @@ def read_and_resize_frames_backward(video_path, img_shape, end_frame):
    reader.close()
    return reversed(frames)
 
+
 if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--config", required=True, help="path to config")
    parser.add_argument("--checkpoint", default='checkpoints/vox.pth.tar', help="path to checkpoint to restore")
 
    parser.add_argument("--source_image", default='./assets/source.png', help="path to source image")
-    parser.add_argument("--driving_video", default='./assets/driving.mp4', help="path to driving video")
-    parser.add_argument("--result_video", default='./result.mp4', help="path to output")
-
+    parser.add_argument("--driving_video", default='./assets/driving.mp4', help="path to driving video or folder of images")
+    parser.add_argument("--result_video", default='./result.mp4', help="path to output. Can be file name or folder.")
+
    parser.add_argument("--img_shape", default="256,256", type=lambda x: list(map(int, x.split(','))),
                        help='Shape of image, that the model was trained on.')
-
-    parser.add_argument("--mode", default='relative', choices=['standard', 'relative', 'avd'], help="Animate mode: ['standard', 'relative', 'avd'], when use the relative mode to animate a face, use '--find_best_frame' can get better quality result")
-
-    parser.add_argument("--find_best_frame", dest="find_best_frame", action="store_true",
+
+    parser.add_argument("--mode", default='relative', choices=['standard', 'relative', 'avd'],
+                        help="Animate mode: ['standard', 'relative', 'avd'], when use the relative mode to animate a face, use '--find_best_frame' can get better quality result")
+
+    parser.add_argument("--find_best_frame", dest="find_best_frame", action="store_true",
                        help="Generate from the frame that is the most alligned with source. (Only for faces, requires face_aligment lib)")
 
    parser.add_argument("--cpu", dest="cpu", action="store_true", help="cpu mode.")
+    parser.add_argument("--autocast", dest="autocast", action="store_true", help="Autocast mode.")
+
 
    opt = parser.parse_args()
 
    source_image = imageio.imread(opt.source_image)
-    reader = imageio.get_reader(opt.driving_video)
-    fps = reader.get_meta_data()['fps']
-    reader.close()
-
-    if opt.cpu:
+
+    if os.path.isdir(opt.driving_video):
+        fps = 30
+        length = len(os.listdir(opt.driving_video))
+    elif "%" in opt.driving_video:
+        fps = 30
+        length = len(os.path.dirname(opt.driving_video))
+    else:
+        reader = imageio.get_reader(opt.driving_video, mode='I')
+
+        fps = reader.get_meta_data().get('fps', 30)
+        length = int(reader.get_meta_data().get('duration', 0)) * int(fps)
+
+        reader.close()
+
+    if opt.cpu and opt.autocast:
+        autocast_dtype = torch.bfloat16
+    else:
+        autocast_dtype = torch.float16
+
+    if opt.cpu or torch.cuda.device_count() == 0:
        device = torch.device('cpu')
    else:
        device = torch.device('cuda')
-
+
    source_image = resize(source_image, opt.img_shape)[..., :3]
-    inpainting, kp_detector, dense_motion_network, avd_network = load_checkpoints(config_path = opt.config, checkpoint_path = opt.checkpoint, device = device)
+    inpainting, kp_detector, dense_motion_network, avd_network = load_checkpoints(config_path=opt.config,
+                                                                                  checkpoint_path=opt.checkpoint,
+                                                                                  device=device)
+
 
    def reversed_generator(generator):
        frames = list(generator)
        return reversed(frames)
 
+
    def append_frame_to_writer(frame, writer):
        writer.append_data(img_as_ubyte(frame))
 
-
+    writer = VideoWriter(opt.result_video, mode='I', fps=fps)
    if opt.find_best_frame:
        driving_video_generator = read_and_resize_frames(opt.driving_video, opt.img_shape)
        i = find_best_frame(source_image, driving_video_generator, opt.cpu)
-        print("Best frame:", i)
        driving_forward = read_and_resize_frames_forward(opt.driving_video, opt.img_shape, i)
        driving_backward = read_and_resize_frames_backward(opt.driving_video, opt.img_shape, i)
 
-        with imageio.get_writer(opt.result_video, mode='I', fps=fps) as writer:
+        with writer:
            # Generate and append frames for the reversed backward animation
            backward_animation = make_animation(source_image, driving_backward, inpainting, kp_detector,
-                                                dense_motion_network, avd_network, device=device, mode=opt.mode)
+                                                dense_motion_network, avd_network, device=device, mode=opt.mode,
+                                                autocast_dtype=autocast_dtype, autocast=opt.autocast)
+
            for frame in reversed_generator(backward_animation):
                append_frame_to_writer(frame, writer)
 
            # Generate and append frames for forward animation, skipping the first frame
-            for idx, frame in enumerate(
+            for idx, frame in tqdm(enumerate(
                make_animation(source_image, driving_forward, inpainting, kp_detector, dense_motion_network,
-                               avd_network, device=device, mode=opt.mode)):
+                               avd_network, device=device, mode=opt.mode, autocast_dtype=autocast_dtype,
+                               autocast=opt.autocast
+                               )), total=length):
                if idx == 0:
                    continue
                append_frame_to_writer(frame, writer)
    else:
-        with imageio.get_writer(opt.result_video, mode='I', fps=fps) as writer:
+        with writer:
            driving_video_generator = read_and_resize_frames(opt.driving_video, opt.img_shape)
-            for frame in make_animation(source_image, driving_video_generator, inpainting, kp_detector, dense_motion_network,
-                               avd_network, device=device, mode=opt.mode):
+            for frame in tqdm(
+                make_animation(source_image, driving_video_generator, inpainting, kp_detector, dense_motion_network,
+                               avd_network, device=device, mode=opt.mode, autocast_dtype=autocast_dtype,
+                               autocast=opt.autocast
+                               ), total=length):
                append_frame_to_writer(frame, writer)
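Note: demo.py now routes all frame I/O through VideoReader and VideoWriter imported from utils, whose changes are not shown in this excerpt. Purely for orientation, a minimal sketch of wrappers that would satisfy the interface the demo relies on (iteration plus close() for reading, append_data() plus context-manager use for writing) is given below. This is an assumption built on plain imageio, not the repository's actual utils implementation, which, judging by the new --driving_video / --result_video help text, likely also handles folders of images.

import imageio


class VideoReader:
    """Hypothetical reader: iterate over the frames of a video file."""

    def __init__(self, path):
        self._reader = imageio.get_reader(path)

    def __iter__(self):
        # imageio readers are themselves iterable over frames
        return iter(self._reader)

    def close(self):
        self._reader.close()


class VideoWriter:
    """Hypothetical writer: append frames, usable as a context manager."""

    def __init__(self, path, mode='I', fps=30):
        self._writer = imageio.get_writer(path, mode=mode, fps=fps)

    def append_data(self, frame):
        self._writer.append_data(frame)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self._writer.close()
        return False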

requirements.txt

Lines changed: 8 additions & 8 deletions
@@ -1,26 +1,26 @@
 cffi==1.14.6
 cycler==0.10.0
 decorator==5.1.0
-face-alignment==1.3.5
+face-alignment==1.4.0
 imageio
 imageio-ffmpeg
 kiwisolver==1.3.2
 matplotlib==3.4.3
 networkx==2.6.3
-numpy==1.20.3
+numpy
 pandas==1.3.3
 Pillow
 pycparser==2.20
 pyparsing==2.4.7
 python-dateutil==2.8.2
 pytz==2021.1
-PyWavelets==1.1.1
+PyWavelets
 PyYAML==5.4.1
-scikit-image==0.18.3
-scikit-learn==1.0
-scipy==1.7.1
+scikit-image
+scikit-learn
+scipy
 six==1.16.0
-#torch==1.11.0+cu113
-#torchvision==0.12.0+cu113
+torch==2.0.1
+torchvision
 tqdm==4.62.3
 wandb
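The move from the commented-out torch==1.11.0+cu113 wheels to torch==2.0.1 lines up with the new --autocast path in demo.py: on CPU, torch.autocast is used with bfloat16 (float16 autocast is primarily a CUDA feature), which is why the demo selects torch.bfloat16 when both --cpu and --autocast are set. A minimal standalone restatement of that selection follows; the helper name autocast_context is mine rather than the repository's.

from contextlib import nullcontext

import torch


def autocast_context(device: torch.device, enabled: bool):
    # Same policy as demo.py: bfloat16 for CPU autocast, float16 otherwise,
    # and a no-op context when autocast is disabled.
    if not enabled:
        return nullcontext()
    dtype = torch.bfloat16 if device.type == 'cpu' else torch.float16
    return torch.autocast(device_type=device.type, dtype=dtype)


# Example: wrap inference the way make_animation now does.
with torch.no_grad(), autocast_context(torch.device('cpu'), enabled=True):
    y = torch.nn.Linear(8, 4)(torch.randn(1, 8))  # runs under bfloat16 autocast on CPU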
