Skip to content

Commit 429a5d1

Browse files
change model definitions and training
1 parent d3b6d0e commit 429a5d1

File tree

12 files changed

+288
-65
lines changed

12 files changed

+288
-65
lines changed

config/vox-256-finetune.yaml

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
dataset_params:
2+
root_dir: ./video-preprocessing/vox2-768
3+
frame_shape: 256,256,3
4+
id_sampling: True
5+
augmentation_params:
6+
flip_param:
7+
horizontal_flip: True
8+
time_flip: True
9+
jitter_param:
10+
brightness: 0.1
11+
contrast: 0.1
12+
saturation: 0.1
13+
hue: 0.1
14+
15+
16+
model_params:
17+
common_params:
18+
num_tps: 10
19+
num_channels: 3
20+
bg: True
21+
multi_mask: True
22+
generator_params:
23+
block_expansion: 64
24+
max_features: 512
25+
num_down_blocks: 3
26+
dense_motion_params:
27+
block_expansion: 64
28+
max_features: 1024
29+
num_blocks: 5
30+
scale_factor: 0.25
31+
avd_network_params:
32+
id_bottle_size: 128
33+
pose_bottle_size: 128
34+
35+
36+
train_params:
37+
num_epochs: 40
38+
num_repeats: 10
39+
epoch_milestones: [15, 30]
40+
lr_generator: 2.0e-4
41+
batch_size: 16
42+
scales: [1, 0.5, 0.25, 0.125]
43+
dataloader_workers: 12
44+
checkpoint_freq: 50
45+
dropout_epoch: 2
46+
dropout_maxp: 0.3
47+
dropout_startp: 0.1
48+
dropout_inc_epoch: 10
49+
bg_start: 5
50+
transform_params:
51+
sigma_affine: 0.05
52+
sigma_tps: 0.005
53+
points_tps: 5
54+
loss_weights:
55+
perceptual: [10, 10, 10, 10, 10]
56+
equivariance_value: 10
57+
warp_loss: 10
58+
bg: 10
59+
optimizer: 'adamw'
60+
optimizer_params:
61+
betas: [ 0.9, 0.999 ]
62+
weight_decay: 0.1
63+
64+
train_avd_params:
65+
num_epochs: 100
66+
num_repeats: 1
67+
batch_size: 8
68+
dataloader_workers: 6
69+
checkpoint_freq: 1
70+
epoch_milestones: [10, 20]
71+
lr: 1.0e-3
72+
lambda_shift: 1
73+
random_scale: 0.25
74+
75+
visualizer_params:
76+
kp_size: 5
77+
draw_border: True
78+
colormap: 'gist_rainbow'

config/vox-256.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,10 @@ train_params:
5656
equivariance_value: 10
5757
warp_loss: 10
5858
bg: 10
59+
optimizer: 'adam'
60+
optimizer_params:
61+
betas: [ 0.5, 0.999 ]
62+
weight_decay: 1e-4
5963

6064
train_avd_params:
6165
num_epochs: 100

config/vox-512-finetune.yaml

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Use this file to finetune from a pretrained 256x256 model
22
dataset_params:
3-
root_dir: vox
4-
frame_shape: null
3+
root_dir: ./video-preprocessing/vox2-768
4+
frame_shape: 512,512,3
55
id_sampling: True
66
augmentation_params:
77
flip_param:
@@ -35,20 +35,20 @@ model_params:
3535

3636

3737
train_params:
38-
num_epochs: 100
39-
num_repeats: 10
40-
epoch_milestones: [70, 90]
38+
num_epochs: 30
39+
num_repeats: 4
40+
epoch_milestones: [20]
4141
# Higher LR seems to bring problems when finetuning
4242
lr_generator: 2.0e-5
4343
batch_size: 4
4444
scales: [1, 0.5, 0.25, 0.125]
4545
dataloader_workers: 6
46-
checkpoint_freq: 2
47-
dropout_epoch: 0
46+
checkpoint_freq: 5
47+
dropout_epoch: 2
4848
dropout_maxp: 0.3
4949
dropout_startp: 0.1
50-
dropout_inc_epoch: 10
51-
bg_start: 0
50+
dropout_inc_epoch: 1
51+
bg_start: 5
5252
transform_params:
5353
sigma_affine: 0.05
5454
sigma_tps: 0.005
@@ -58,13 +58,17 @@ train_params:
5858
equivariance_value: 10
5959
warp_loss: 10
6060
bg: 10
61+
optimizer: 'adamw'
62+
optimizer_params:
63+
betas: [0.9, 0.999]
64+
weight_decay: 0.1
6165

6266
train_avd_params:
6367
num_epochs: 200
6468
num_repeats: 1
6569
batch_size: 4
6670
dataloader_workers: 6
67-
checkpoint_freq: 2
71+
checkpoint_freq: 10
6872
epoch_milestones: [10, 20]
6973
lr: 1.0e-3
7074
lambda_shift: 1

config/vox-768-finetune.yaml

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Use this file to finetune from a pretrained 256x256 model
22
dataset_params:
3-
root_dir: vox_768
3+
root_dir: ./video-preprocessing/vox2-768
44
frame_shape: null
55
id_sampling: True
66
augmentation_params:
@@ -35,20 +35,20 @@ model_params:
3535

3636

3737
train_params:
38-
num_epochs: 100
39-
num_repeats: 1
40-
epoch_milestones: [70, 90]
38+
visualize_model: False
39+
num_epochs: 40
40+
num_repeats: 4
4141
# Higher LR seems to bring problems when finetuning
4242
lr_generator: 2.0e-5
43-
batch_size: 1
43+
batch_size: 2
4444
scales: [1, 0.5, 0.25, 0.125]
45-
dataloader_workers: 6
46-
checkpoint_freq: 1
45+
dataloader_workers: 8
46+
checkpoint_freq: 2
4747
dropout_epoch: 0
4848
dropout_maxp: 0.3
4949
dropout_startp: 0.1
5050
dropout_inc_epoch: 10
51-
bg_start: 0
51+
bg_start: 5
5252
transform_params:
5353
sigma_affine: 0.05
5454
sigma_tps: 0.005
@@ -58,6 +58,10 @@ train_params:
5858
equivariance_value: 10
5959
warp_loss: 10
6060
bg: 10
61+
optimizer: 'adamw'
62+
optimizer_params:
63+
betas: [ 0.9, 0.999 ]
64+
weight_decay: 0.1
6165

6266
train_avd_params:
6367
num_epochs: 200
@@ -73,4 +77,4 @@ train_avd_params:
7377
visualizer_params:
7478
kp_size: 5
7579
draw_border: True
76-
colormap: 'gist_rainbow'
80+
colormap: 'gist_rainbow'

frames_dataset.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,8 @@ def __init__(self, root_dir, frame_shape=(256, 256, 3), id_sampling=False, is_tr
6767
random_seed=0, pairs_list=None, augmentation_params=None):
6868
self.root_dir = root_dir
6969
self.videos = os.listdir(root_dir)
70+
if type(frame_shape) == str:
71+
frame_shape = tuple(map(int, frame_shape.split(',')))
7072
self.frame_shape = frame_shape
7173
print(self.frame_shape)
7274
self.pairs_list = pairs_list
@@ -115,7 +117,13 @@ def __getitem__(self, idx):
115117

116118
frames = os.listdir(path)
117119
num_frames = len(frames)
118-
frame_idx = np.sort(np.random.choice(num_frames, replace=True, size=2))
120+
# use more frames that are different from each other to speed up training
121+
min_frames_apart = num_frames // 4
122+
first_frame_idx = np.random.choice(num_frames - min_frames_apart)
123+
second_frame_idx = np.random.choice(range(first_frame_idx + min_frames_apart, num_frames))
124+
frame_idx = np.array([first_frame_idx, second_frame_idx])
125+
np.random.shuffle(frame_idx)
126+
#frame_idx = np.sort(np.random.choice(num_frames, replace=True, size=2))
119127

120128
if self.frame_shape is not None:
121129
resize_fn = partial(resize, output_shape=self.frame_shape)

logger.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,13 @@ class Logger:
1717
def __init__(self, log_dir, checkpoint_freq=50, visualizer_params=None,
1818
zfill_num=8, log_file_name='log.txt', models=()):
1919

20-
self.models = None
20+
self.models = models
2121
self.loss_list = []
2222
self.cpk_dir = log_dir
2323
self.visualizations_dir = os.path.join(log_dir, 'train-vis')
2424
if not os.path.exists(self.visualizations_dir):
2525
os.makedirs(self.visualizations_dir)
26+
print("Visualizations will be saved in %s" % self.visualizations_dir)
2627
self.log_file = open(os.path.join(log_dir, log_file_name), 'a')
2728
self.zfill_num = zfill_num
2829
self.visualizer = Visualizer(**visualizer_params)
@@ -46,9 +47,10 @@ def log_scores(self, loss_names):
4647

4748
def visualize_rec(self, inp, out):
4849
image = self.visualizer.visualize(inp['driving'], inp['source'], out)
50+
wandb.log({"image": [wandb.Image(image)]})
4951
imageio.imsave(os.path.join(self.visualizations_dir, "%s-rec.png" % str(self.epoch).zfill(self.zfill_num)),
5052
image)
51-
wandb.log({"image": [wandb.Image(image)]})
53+
5254

5355
def save_cpk(self, emergent=False):
5456
cpk = {k: v.state_dict() for k, v in self.models.items()}

modules/bg_motion_predictor.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,24 @@ class BGMotionPredictor(nn.Module):
1111
def __init__(self):
1212
super(BGMotionPredictor, self).__init__()
1313
self.bg_encoder = models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)
14+
self.preprocess = torchvision.transforms.Compose([
15+
torchvision.transforms.Resize((256, 256)),
16+
])
1417
self.bg_encoder.conv1 = nn.Conv2d(6, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
1518
num_features = self.bg_encoder.fc.in_features
1619
self.bg_encoder.fc = nn.Linear(num_features, 6)
1720
self.bg_encoder.fc.weight.data.zero_()
1821
self.bg_encoder.fc.bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))
1922

2023
def forward(self, source_image, driving_image):
24+
25+
2126
bs = source_image.shape[0]
2227
out = torch.eye(3).unsqueeze(0).repeat(bs, 1, 1).type(source_image.type())
28+
29+
source_image = self.preprocess(source_image)
30+
driving_image = self.preprocess(driving_image)
31+
2332
prediction = self.bg_encoder(torch.cat([source_image, driving_image], dim=1))
2433
out[:, :2, :] = prediction.view(bs, 2, 3)
2534
return out

modules/keypoint_detector.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,13 @@ def __init__(self, num_tps, **kwargs):
1515
self.fg_encoder = models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)
1616
num_features = self.fg_encoder.fc.in_features
1717
self.fg_encoder.fc = nn.Linear(num_features, num_tps*5*2)
18+
self.preprocess = torchvision.transforms.Compose([
19+
torchvision.transforms.Resize((256, 256)),
20+
])
1821

1922

2023
def forward(self, image):
24+
image = self.preprocess(image)
2125

2226
fg_kp = self.fg_encoder(image)
2327
bs, _, = fg_kp.shape

modules/util.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -150,10 +150,10 @@ def __init__(self, in_features, kernel_size, padding):
150150

151151
def forward(self, x):
152152
out = self.norm1(x)
153-
out = F.relu(out)
153+
out = F.mish(out)
154154
out = self.conv1(out)
155155
out = self.norm2(out)
156-
out = F.relu(out)
156+
out = F.mish(out)
157157
out = self.conv2(out)
158158
out += x
159159
return out
@@ -172,10 +172,10 @@ def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1
172172
self.norm = nn.InstanceNorm2d(out_features, affine=True)
173173

174174
def forward(self, x):
175-
out = F.interpolate(x, scale_factor=2)
175+
out = F.interpolate(x, scale_factor=2, mode='nearest')
176176
out = self.conv(out)
177177
out = self.norm(out)
178-
out = F.relu(out)
178+
out = F.mish(out)
179179
return out
180180

181181

@@ -194,7 +194,7 @@ def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1
194194
def forward(self, x):
195195
out = self.conv(x)
196196
out = self.norm(out)
197-
out = F.relu(out)
197+
out = F.mish(out)
198198
out = self.pool(out)
199199
return out
200200

@@ -213,7 +213,7 @@ def __init__(self, in_features, out_features, groups=1, kernel_size=3, padding=1
213213
def forward(self, x):
214214
out = self.conv(x)
215215
out = self.norm(out)
216-
out = F.relu(out)
216+
out = F.mish(out)
217217
return out
218218

219219

run.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from train_avd import train_avd
1919
from reconstruction import reconstruction
2020
import os
21+
from torchinfo import summary
2122
import bitsandbytes as bnb
2223

2324
optimizer_choices = {
@@ -37,7 +38,7 @@
3738
parser.add_argument("--mode", default="train", choices=["train", "reconstruction", "train_avd"])
3839
parser.add_argument("--log_dir", default='log', help="path to log into")
3940
parser.add_argument("--checkpoint", default=None, help="path to checkpoint to restore")
40-
parser.add_argument("--optimizer_class", default="adam", choices=optimizer_choices.keys())
41+
parser.add_argument("--detect_anomaly", action="store_true", help="detect anomaly in autograd")
4142

4243

4344
opt = parser.parse_args()
@@ -50,6 +51,9 @@
5051
log_dir = os.path.join(opt.log_dir, os.path.basename(opt.config).split('.')[0])
5152
log_dir += ' ' + strftime("%d_%m_%y_%H.%M.%S", gmtime())
5253

54+
if opt.detect_anomaly:
55+
torch.autograd.set_detect_anomaly(True)
56+
5357
inpainting = InpaintingNetwork(**config['model_params']['generator_params'],
5458
**config['model_params']['common_params'])
5559

@@ -76,7 +80,17 @@
7680
if not os.path.exists(os.path.join(log_dir, os.path.basename(opt.config))):
7781
copy(opt.config, log_dir)
7882

79-
optimizer_class = optimizer_choices[opt.optimizer_class]
83+
optimizer_class = optimizer_choices[config['train_params']['optimizer']]
84+
85+
print("Inpainting Network:")
86+
summary(inpainting)
87+
print("Keypoint Detector:")
88+
summary(kp_detector)
89+
print("Dense Motion Network:")
90+
summary(dense_motion_network)
91+
if bg_predictor is not None:
92+
print("Background Predictor:")
93+
summary(bg_predictor)
8094

8195
if opt.mode == 'train':
8296
print("Training...")
@@ -90,3 +104,4 @@
90104
print("Reconstruction...")
91105
#TODO: update to accelerate
92106
reconstruction(config, inpainting, kp_detector, bg_predictor, dense_motion_network, opt.checkpoint, log_dir, dataset)
107+

0 commit comments

Comments
 (0)