
Commit 9f6ade8

fix memory error
1 parent 91f9b54 commit 9f6ade8

13 files changed: +141 −76 lines

README.md

Lines changed: 4 additions & 4 deletions
@@ -52,15 +52,15 @@ See description of the parameters in the ```config/taichi-256.yaml```.
 ### Training
 To train a model on specific dataset run:
 ```
-CUDA_VISIBLE_DEVICES=0,1 python run.py --config config/dataset_name.yaml --device_ids 0,1
+accelerate launch run.py --config config/dataset_name.yaml --device_ids 0,1
 ```
 A log folder named after the timestamp will be created. Checkpoints, loss values, reconstruction results will be saved to this folder.


 #### Training AVD network
 To train a model on specific dataset run:
 ```
-CUDA_VISIBLE_DEVICES=0 python run.py --mode train_avd --checkpoint '{checkpoint_folder}/checkpoint.pth.tar' --config config/dataset_name.yaml
+accelerate launch run.py --mode train_avd --checkpoint '{checkpoint_folder}/checkpoint.pth.tar' --config config/dataset_name.yaml
 ```
 Checkpoints, loss values, reconstruction results will be saved to `{checkpoint_folder}`.

@@ -70,7 +70,7 @@ Checkpoints, loss values, reconstruction results will be saved to `{checkpoint_f

 To evaluate the reconstruction performance run:
 ```
-CUDA_VISIBLE_DEVICES=0 python run.py --mode reconstruction --config config/dataset_name.yaml --checkpoint '{checkpoint_folder}/checkpoint.pth.tar'
+accelerate launch run.py --mode reconstruction --config config/dataset_name.yaml --checkpoint '{checkpoint_folder}/checkpoint.pth.tar'
 ```
 The `reconstruction` subfolder will be created in `{checkpoint_folder}`.
 The generated video will be stored to this folder, also generated videos will be stored in ```png``` subfolder in loss-less '.png' format for evaluation.
@@ -81,7 +81,7 @@ To compute metrics, follow instructions from [pose-evaluation](https://github.co
 - notebook: `demo.ipynb`, edit the config cell and run for image animation.
 - python:
 ```bash
-CUDA_VISIBLE_DEVICES=0 python demo.py --config config/vox-256.yaml --checkpoint checkpoints/vox.pth.tar --source_image ./source.jpg --driving_video ./driving.mp4
+python demo.py --config config/vox-256.yaml --checkpoint checkpoints/vox.pth.tar --source_image ./source.jpg --driving_video ./driving.mp4
 ```

 # Acknowledgments
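Note on the command change above: the README now hands device placement to Hugging Face Accelerate (`accelerate launch`) instead of pinning GPUs with `CUDA_VISIBLE_DEVICES`. For context, here is a minimal, self-contained sketch of the training-loop pattern such a launch command generally assumes inside the launched script; the tiny model, optimizer, and dataset are placeholders for illustration, not the repo's actual `run.py`.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

accelerator = Accelerator()  # process count / devices come from `accelerate launch` or `accelerate config`

# Placeholder model and data, purely for illustration
model = nn.Linear(16, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
dataloader = DataLoader(TensorDataset(torch.randn(64, 16), torch.randn(64, 1)), batch_size=8)

# prepare() moves everything to the right device and wraps the model for multi-GPU runs
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

for x, y in dataloader:
    loss = nn.functional.mse_loss(model(x), y)
    accelerator.backward(loss)  # used in place of loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```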

config/vox-256.yaml

Lines changed: 7 additions & 7 deletions
@@ -1,5 +1,5 @@
 dataset_params:
-  root_dir: ../vox
+  root_dir: vox_256
   frame_shape: null
   id_sampling: True
   augmentation_params:
@@ -58,12 +58,12 @@ train_params:
     bg: 10

 train_avd_params:
-  num_epochs: 200
-  num_repeats: 300
-  batch_size: 256
-  dataloader_workers: 24
-  checkpoint_freq: 50
-  epoch_milestones: [140, 180]
+  num_epochs: 100
+  num_repeats: 1
+  batch_size: 8
+  dataloader_workers: 6
+  checkpoint_freq: 1
+  epoch_milestones: [10, 20]
   lr: 1.0e-3
   lambda_shift: 1
   random_scale: 0.25

config/vox-512.yaml renamed to config/vox-512-finetune.yaml

Lines changed: 13 additions & 11 deletions
@@ -1,3 +1,4 @@
+# Use this file to finetune from a pretrained 256x256 model
 dataset_params:
   root_dir: vox
   frame_shape: null
@@ -35,18 +36,19 @@ model_params:

 train_params:
   num_epochs: 100
-  num_repeats: 75
+  num_repeats: 10
   epoch_milestones: [70, 90]
-  lr_generator: 2.0e-4
+  # Higher LR seems to bring problems when finetuning
+  lr_generator: 2.0e-5
   batch_size: 4
   scales: [1, 0.5, 0.25, 0.125]
-  dataloader_workers: 12
-  checkpoint_freq: 1
-  dropout_epoch: 35
+  dataloader_workers: 6
+  checkpoint_freq: 2
+  dropout_epoch: 0
   dropout_maxp: 0.3
   dropout_startp: 0.1
   dropout_inc_epoch: 10
-  bg_start: 10
+  bg_start: 0
   transform_params:
     sigma_affine: 0.05
     sigma_tps: 0.005
@@ -59,11 +61,11 @@ train_params:

 train_avd_params:
   num_epochs: 200
-  num_repeats: 300
-  batch_size: 256
-  dataloader_workers: 24
-  checkpoint_freq: 50
-  epoch_milestones: [140, 180]
+  num_repeats: 1
+  batch_size: 4
+  dataloader_workers: 6
+  checkpoint_freq: 2
+  epoch_milestones: [10, 20]
   lr: 1.0e-3
   lambda_shift: 1
   random_scale: 0.25

config/vox-768.yaml renamed to config/vox-768-finetune.yaml

Lines changed: 11 additions & 9 deletions
@@ -1,3 +1,4 @@
+# Use this file to finetune from a pretrained 256x256 model
 dataset_params:
   root_dir: vox_768
   frame_shape: null
@@ -35,18 +36,19 @@ model_params:

 train_params:
   num_epochs: 100
-  num_repeats: 75
+  num_repeats: 1
   epoch_milestones: [70, 90]
-  lr_generator: 2.0e-4
+  # Higher LR seems to bring problems when finetuning
+  lr_generator: 2.0e-5
   batch_size: 1
   scales: [1, 0.5, 0.25, 0.125]
-  dataloader_workers: 12
+  dataloader_workers: 6
   checkpoint_freq: 1
-  dropout_epoch: 35
+  dropout_epoch: 0
   dropout_maxp: 0.3
   dropout_startp: 0.1
   dropout_inc_epoch: 10
-  bg_start: 10
+  bg_start: 0
   transform_params:
     sigma_affine: 0.05
     sigma_tps: 0.005
@@ -59,10 +61,10 @@ train_params:

 train_avd_params:
   num_epochs: 200
-  num_repeats: 300
-  batch_size: 256
-  dataloader_workers: 24
-  checkpoint_freq: 50
+  num_repeats: 1
+  batch_size: 1
+  dataloader_workers: 6
+  checkpoint_freq: 1
   epoch_milestones: [140, 180]
   lr: 1.0e-3
   lambda_shift: 1

demo.py

Lines changed: 42 additions & 18 deletions
@@ -67,19 +67,24 @@ def load_checkpoints(config_path, checkpoint_path, device):
     return inpainting, kp_detector, dense_motion_network, avd_network


-def make_animation(source_image, driving_video, inpainting_network, kp_detector, dense_motion_network, avd_network, device:torch.device, mode = 'relative'):
+def make_animation(source_image, driving_video_generator, inpainting_network, kp_detector, dense_motion_network, avd_network, device:torch.device, mode = 'relative'):
     assert mode in ['standard', 'relative', 'avd']
     with torch.no_grad():
         with torch.autocast(device_type=str(device), dtype=torch.float16):
             source = torch.tensor(source_image[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2)
             source = source.to(device)
-            driving = torch.tensor(np.array(driving_video)[np.newaxis].astype(np.float32)).permute(0, 4, 1, 2, 3).to(device)
+            #driving = torch.tensor(np.array(driving_video)[np.newaxis].astype(np.float32)).permute(0, 4, 1, 2, 3).to(device)
             kp_source = kp_detector(source)
-            kp_driving_initial = kp_detector(driving[:, :, 0])

-            for frame_idx in tqdm(range(driving.shape[2])):
-                driving_frame = driving[:, :, frame_idx]
-                driving_frame = driving_frame.to(device)
+            first_frame = True
+
+            for driving_frame_np in tqdm(driving_video_generator):
+
+                driving_frame = torch.tensor(driving_frame_np[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2).to(
+                    device)
+                if first_frame:
+                    kp_driving_initial = kp_detector(driving_frame)
+                    first_frame = False
                 kp_driving = kp_detector(driving_frame)
                 if mode == 'standard':
                     kp_norm = kp_driving
@@ -95,7 +100,6 @@ def make_animation(source_image, driving_video, inpainting, kp_detector,

                 yield np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0]

-
 def find_best_frame(source, driving, cpu):
     import face_alignment

@@ -123,7 +127,32 @@ def normalize_kp(kp):
         except:
             pass
     return frame_num
+def read_and_resize_frames(video_path, img_shape):
+    reader = imageio.get_reader(video_path)
+    for frame in reader:
+        resized_frame = resize(frame, img_shape)[..., :3]
+        yield resized_frame
+    reader.close()

+def read_and_resize_frames_forward(video_path, img_shape, start_frame):
+    reader = imageio.get_reader(video_path)
+    for idx, frame in enumerate(reader):
+        if idx < start_frame:
+            continue
+        resized_frame = resize(frame, img_shape)[..., :3]
+        yield resized_frame
+    reader.close()
+
+def read_and_resize_frames_backward(video_path, img_shape, end_frame):
+    reader = imageio.get_reader(video_path)
+    frames = []
+    for idx, frame in enumerate(reader):
+        if idx > end_frame:
+            break
+        resized_frame = resize(frame, img_shape)[..., :3]
+        frames.append(resized_frame)
+    reader.close()
+    return reversed(frames)

 if __name__ == "__main__":
     parser = ArgumentParser()
@@ -149,12 +178,6 @@ def normalize_kp(kp):
     source_image = imageio.imread(opt.source_image)
     reader = imageio.get_reader(opt.driving_video)
     fps = reader.get_meta_data()['fps']
-    driving_video = []
-    try:
-        for im in reader:
-            driving_video.append(im)
-    except RuntimeError:
-        pass
     reader.close()

     if opt.cpu:
@@ -163,7 +186,6 @@ def normalize_kp(kp):
         device = torch.device('cuda')

     source_image = resize(source_image, opt.img_shape)[..., :3]
-    driving_video = [resize(frame, opt.img_shape)[..., :3] for frame in driving_video]
     inpainting, kp_detector, dense_motion_network, avd_network = load_checkpoints(config_path = opt.config, checkpoint_path = opt.checkpoint, device = device)

     def reversed_generator(generator):
@@ -175,10 +197,11 @@ def append_frame_to_writer(frame, writer):


     if opt.find_best_frame:
-        i = find_best_frame(source_image, driving_video, opt.cpu)
+        driving_video_generator = read_and_resize_frames(opt.driving_video, opt.img_shape)
+        i = find_best_frame(source_image, driving_video_generator, opt.cpu)
         print("Best frame:", i)
-        driving_forward = driving_video[i:]
-        driving_backward = driving_video[:(i + 1)][::-1]
+        driving_forward = read_and_resize_frames_forward(opt.driving_video, opt.img_shape, i)
+        driving_backward = read_and_resize_frames_backward(opt.driving_video, opt.img_shape, i)

         with imageio.get_writer(opt.result_video, mode='I', fps=fps) as writer:
             # Generate and append frames for the reversed backward animation
@@ -196,6 +219,7 @@ def append_frame_to_writer(frame, writer):
                 append_frame_to_writer(frame, writer)
     else:
         with imageio.get_writer(opt.result_video, mode='I', fps=fps) as writer:
-            for frame in make_animation(source_image, driving_video, inpainting, kp_detector, dense_motion_network,
+            driving_video_generator = read_and_resize_frames(opt.driving_video, opt.img_shape)
+            for frame in make_animation(source_image, driving_video_generator, inpainting, kp_detector, dense_motion_network,
                 avd_network, device=device, mode=opt.mode):
                 append_frame_to_writer(frame, writer)
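The demo.py hunks above are the heart of the memory fix: instead of stacking the whole driving video into a single `(1, C, T, H, W)` tensor on the GPU, the new `read_and_resize_frames*` generators decode and resize one frame at a time, and `make_animation` consumes them, taking `kp_driving_initial` from the first yielded frame. A minimal sketch of the same streaming pattern, with a placeholder video path, might look like this:

```python
import imageio
import numpy as np
import torch
from skimage.transform import resize

def stream_frames(video_path, img_shape=(256, 256)):
    # Yield resized RGB frames one at a time instead of loading the whole clip
    reader = imageio.get_reader(video_path)
    for frame in reader:
        yield resize(frame, img_shape)[..., :3]  # drop any alpha channel
    reader.close()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# './driving.mp4' is a placeholder path; only one frame is resident at a time.
for frame_np in stream_frames('./driving.mp4'):
    frame = torch.tensor(frame_np[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2).to(device)
    # ... per-frame keypoint detection and generation would go here ...
```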

frames_dataset.py

Lines changed: 0 additions & 2 deletions
@@ -74,7 +74,6 @@ def __init__(self, root_dir, frame_shape=(256, 256, 3), id_sampling=False, is_tr

         if os.path.exists(os.path.join(root_dir, 'train')):
             assert os.path.exists(os.path.join(root_dir, 'test'))
-            print("Use predefined train-test split.")
             if id_sampling:
                 train_videos = {os.path.basename(video).split('#')[0] for video in
                                 os.listdir(os.path.join(root_dir, 'train'))}
@@ -84,7 +83,6 @@ def __init__(self, root_dir, frame_shape=(256, 256, 3), id_sampling=False, is_tr
             test_videos = os.listdir(os.path.join(root_dir, 'test'))
             self.root_dir = os.path.join(self.root_dir, 'train' if is_train else 'test')
         else:
-            print("Use random train-test split.")
             train_videos, test_videos = train_test_split(self.videos, random_state=random_seed, test_size=0.2)

         if is_train:

logger.py

Lines changed: 18 additions & 8 deletions
@@ -1,3 +1,5 @@
+from typing import Dict
+
 import numpy as np
 import torch
 import torch.nn.functional as F
@@ -15,6 +17,7 @@ class Logger:
     def __init__(self, log_dir, checkpoint_freq=50, visualizer_params=None,
                  zfill_num=8, log_file_name='log.txt', models=()):

+        self.models = None
         self.loss_list = []
         self.cpk_dir = log_dir
         self.visualizations_dir = os.path.join(log_dir, 'train-vis')
@@ -43,19 +46,21 @@ def log_scores(self, loss_names):

     def visualize_rec(self, inp, out):
         image = self.visualizer.visualize(inp['driving'], inp['source'], out)
-        imageio.imsave(os.path.join(self.visualizations_dir, "%s-rec.png" % str(self.epoch).zfill(self.zfill_num)), image)
+        imageio.imsave(os.path.join(self.visualizations_dir, "%s-rec.png" % str(self.epoch).zfill(self.zfill_num)),
+                       image)
+        wandb.log({"image": [wandb.Image(image)]})

     def save_cpk(self, emergent=False):
         cpk = {k: v.state_dict() for k, v in self.models.items()}
         cpk['epoch'] = self.epoch
-        cpk_path = os.path.join(self.cpk_dir, '%s-checkpoint.pth.tar' % str(self.epoch).zfill(self.zfill_num))
+        cpk_path = os.path.join(self.cpk_dir, '%s-checkpoint.pth.tar' % str(self.epoch).zfill(self.zfill_num))
         if not (os.path.exists(cpk_path) and emergent):
             torch.save(cpk, cpk_path)

     @staticmethod
-    def load_cpk(checkpoint_path, inpainting_network=None, dense_motion_network =None, kp_detector=None,
-                 bg_predictor=None, avd_network=None, optimizer=None, optimizer_bg_predictor=None,
-                 optimizer_avd=None):
+    def load_cpk(checkpoint_path, inpainting_network=None, dense_motion_network=None, kp_detector=None,
+                 bg_predictor=None, avd_network=None, optimizer=None, optimizer_bg_predictor=None,
+                 optimizer_avd=None):
         checkpoint = torch.load(checkpoint_path)
         if inpainting_network is not None:
             inpainting_network.load_state_dict(checkpoint['inpainting_network'])
@@ -78,6 +83,9 @@ def load_cpk(checkpoint_path, inpainting_network=None, dense_motion_network =Non
         epoch = -1
         if 'epoch' in checkpoint:
             epoch = checkpoint['epoch']
+
+        print('Loaded checkpoint from epoch %d' % epoch)
+        print('keys: ', checkpoint.keys())
         return epoch

     def __enter__(self):
@@ -89,10 +97,12 @@ def __exit__(self, exc_type, exc_value, tb):
         self.log_file.close()
         wandb.finish()

-    def log_iter(self, losses):
+    def log_iter(self, losses, others: Dict = None):
         losses = collections.OrderedDict(losses.items())
         self.names = list(losses.keys())
         self.loss_list.append(list(losses.values()))
+        if others is not None:
+            losses.update(others)
         wandb.log(losses)

     def log_epoch(self, epoch, models, inp, out):
@@ -176,7 +186,6 @@ def visualize(self, driving, source, out):
            images.append((prediction, kp_norm))
        images.append(prediction)

-
        ## Occlusion map
        if 'occlusion_map' in out:
            for i in range(len(out['occlusion_map'])):
@@ -192,7 +201,7 @@ def visualize(self, driving, source, out):
                image = out['deformed_source'][:, i].data.cpu()
                # import ipdb;ipdb.set_trace()
                image = F.interpolate(image, size=source.shape[1:3])
-                mask = out['contribution_maps'][:, i:(i+1)].data.cpu().repeat(1, 3, 1, 1)
+                mask = out['contribution_maps'][:, i:(i + 1)].data.cpu().repeat(1, 3, 1, 1)
                mask = F.interpolate(mask, size=source.shape[1:3])
                image = np.transpose(image.numpy(), (0, 2, 3, 1))
                mask = np.transpose(mask.numpy(), (0, 2, 3, 1))
@@ -216,4 +225,5 @@ def visualize(self, driving, source, out):

        image = self.create_image_grid(*images)
        image = (255 * image).astype(np.uint8)
+
        return image
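The `log_iter` change above lets callers attach extra scalars (for example the current learning rate) to the same wandb step without treating them as losses: only the original `losses` values are appended to `self.loss_list` for epoch averaging, while `others` is merged in just before `wandb.log`. A small stand-alone illustration of that merge, with made-up values:

```python
import collections

losses = collections.OrderedDict({'perceptual': 74.2, 'equivariance_value': 0.31})
loss_list = [list(losses.values())]   # only real losses are kept for epoch averaging

others = {'lr': 2.0e-5}               # hypothetical extra scalar
if others is not None:
    losses.update(others)             # the merged dict is what would reach wandb.log

print(loss_list)      # [[74.2, 0.31]]
print(dict(losses))   # {'perceptual': 74.2, 'equivariance_value': 0.31, 'lr': 2e-05}
```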

modules/bg_motion_predictor.py

Lines changed: 2 additions & 1 deletion
@@ -1,3 +1,4 @@
+import torchvision
 from torch import nn
 import torch
 from torchvision import models
@@ -9,7 +10,7 @@ class BGMotionPredictor(nn.Module):

     def __init__(self):
         super(BGMotionPredictor, self).__init__()
-        self.bg_encoder = models.resnet18(pretrained=False)
+        self.bg_encoder = models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)
         self.bg_encoder.conv1 = nn.Conv2d(6, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
         num_features = self.bg_encoder.fc.in_features
         self.bg_encoder.fc = nn.Linear(num_features, 6)
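The BGMotionPredictor hunk swaps the deprecated `pretrained=` argument for torchvision's newer `weights=` enum. Note this is also a behavioural change: `ResNet18_Weights.DEFAULT` downloads ImageNet weights, whereas `pretrained=False` started from random initialization (the replaced 6-channel `conv1` discards its pretrained kernel either way). Below is a stand-alone sketch of the same construction, assuming the 6 input channels are stacked source and driving frames and the 6 outputs are background transform parameters:

```python
import torch
from torch import nn
from torchvision import models

# New-style torchvision API: a weights enum instead of pretrained=True/False
encoder = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

# Accept 6-channel input; the pretrained 3-channel conv1 kernel is discarded here
encoder.conv1 = nn.Conv2d(6, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
encoder.fc = nn.Linear(encoder.fc.in_features, 6)

x = torch.randn(1, 6, 256, 256)   # dummy batch
print(encoder(x).shape)           # torch.Size([1, 6])
```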
