Skip to content

Commit f6deb7b

Browse files
change model definitions and training
1 parent 429a5d1 commit f6deb7b

File tree

8 files changed

+150
-49
lines changed

8 files changed

+150
-49
lines changed

config/vox-1024-finetune.yaml

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
# Use this file to finetune from a pretrained 256x256 model
2+
dataset_params:
3+
root_dir: ./video-preprocessing/vox2-768
4+
frame_shape: 1024,1024,3
5+
id_sampling: True
6+
augmentation_params:
7+
flip_param:
8+
horizontal_flip: True
9+
time_flip: True
10+
jitter_param:
11+
brightness: 0.1
12+
contrast: 0.1
13+
saturation: 0.1
14+
hue: 0.1
15+
16+
17+
model_params:
18+
common_params:
19+
num_tps: 10
20+
num_channels: 3
21+
bg: True
22+
multi_mask: True
23+
generator_params:
24+
block_expansion: 64
25+
max_features: 512
26+
num_down_blocks: 3
27+
dense_motion_params:
28+
block_expansion: 64
29+
max_features: 1024
30+
num_blocks: 5
31+
scale_factor: 0.25
32+
avd_network_params:
33+
id_bottle_size: 128
34+
pose_bottle_size: 128
35+
36+
37+
train_params:
38+
num_epochs: 5
39+
num_repeats: 4
40+
# A higher LR seems to cause problems when finetuning
41+
lr_generator: 2.0e-5
42+
batch_size: 1
43+
scales: [1, 0.5, 0.25, 0.125, 0.0625, 0.03125]
44+
dataloader_workers: 6
45+
checkpoint_freq: 5
46+
dropout_epoch: 2
47+
dropout_maxp: 0.3
48+
dropout_startp: 0.1
49+
dropout_inc_epoch: 1
50+
bg_start: 81
51+
freeze_kp_detector: True
52+
freeze_bg_predictor: True
53+
transform_params:
54+
sigma_affine: 0.05
55+
sigma_tps: 0.005
56+
points_tps: 5
57+
loss_weights:
58+
perceptual: [10, 10, 10, 10, 10]
59+
equivariance_value: 10
60+
warp_loss: 10
61+
bg: 10
62+
optimizer: 'adamw'
63+
optimizer_params:
64+
betas: [ 0.9, 0.999 ]
65+
weight_decay: 0.1
66+
67+
68+
train_avd_params:
69+
num_epochs: 200
70+
num_repeats: 1
71+
batch_size: 1
72+
dataloader_workers: 6
73+
checkpoint_freq: 1
74+
epoch_milestones: [140, 180]
75+
lr: 1.0e-3
76+
lambda_shift: 1
77+
random_scale: 0.25
78+
79+
visualizer_params:
80+
kp_size: 5
81+
draw_border: True
82+
colormap: 'gist_rainbow'

config/vox-256-finetune.yaml

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,19 +34,20 @@ model_params:
3434

3535

3636
train_params:
37-
num_epochs: 40
37+
num_epochs: 5
3838
num_repeats: 10
39-
epoch_milestones: [15, 30]
40-
lr_generator: 2.0e-4
39+
lr_generator: 2.0e-5
4140
batch_size: 16
4241
scales: [1, 0.5, 0.25, 0.125]
4342
dataloader_workers: 12
44-
checkpoint_freq: 50
45-
dropout_epoch: 2
43+
checkpoint_freq: 10
44+
dropout_epoch: 0
4645
dropout_maxp: 0.3
4746
dropout_startp: 0.1
4847
dropout_inc_epoch: 10
49-
bg_start: 5
48+
bg_start: 6
49+
freeze_kp_detector: False
50+
freeze_bg_predictor: True
5051
transform_params:
5152
sigma_affine: 0.05
5253
sigma_tps: 0.005
@@ -61,6 +62,7 @@ train_params:
6162
betas: [ 0.9, 0.999 ]
6263
weight_decay: 0.1
6364

65+
6466
train_avd_params:
6567
num_epochs: 100
6668
num_repeats: 1

config/vox-512-finetune.yaml

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,20 +35,21 @@ model_params:
3535

3636

3737
train_params:
38-
num_epochs: 30
38+
num_epochs: 40
3939
num_repeats: 4
40-
epoch_milestones: [20]
4140
# A higher LR seems to cause problems when finetuning
42-
lr_generator: 2.0e-5
41+
lr_generator: 2.0e-4
4342
batch_size: 4
44-
scales: [1, 0.5, 0.25, 0.125]
43+
scales: [1, 0.5, 0.25, 0.125, 0.0625]
4544
dataloader_workers: 6
4645
checkpoint_freq: 5
4746
dropout_epoch: 2
4847
dropout_maxp: 0.3
4948
dropout_startp: 0.1
5049
dropout_inc_epoch: 1
51-
bg_start: 5
50+
bg_start: 41
51+
freeze_kp_detector: True
52+
freeze_bg_predictor: True
5253
transform_params:
5354
sigma_affine: 0.05
5455
sigma_tps: 0.005

config/vox-768-finetune.yaml

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Use this file to finetune from a pretrained 256x256 model
22
dataset_params:
33
root_dir: ./video-preprocessing/vox2-768
4-
frame_shape: null
4+
frame_shape: 768,768,3
55
id_sampling: True
66
augmentation_params:
77
flip_param:
@@ -36,19 +36,21 @@ model_params:
3636

3737
train_params:
3838
visualize_model: False
39-
num_epochs: 40
40-
num_repeats: 4
39+
num_epochs: 80
40+
num_repeats: 10
4141
# A higher LR seems to cause problems when finetuning
42-
lr_generator: 2.0e-5
42+
lr_generator: 3.0e-5
4343
batch_size: 2
44-
scales: [1, 0.5, 0.25, 0.125]
44+
scales: [1, 0.5, 0.25, 0.125, 0.0625]
4545
dataloader_workers: 8
4646
checkpoint_freq: 2
4747
dropout_epoch: 0
4848
dropout_maxp: 0.3
4949
dropout_startp: 0.1
5050
dropout_inc_epoch: 10
51-
bg_start: 5
51+
bg_start: 81
52+
freeze_kp_detector: True
53+
freeze_bg_predictor: True
5254
transform_params:
5355
sigma_affine: 0.05
5456
sigma_tps: 0.005

modules/bg_motion_predictor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ def __init__(self):
1212
super(BGMotionPredictor, self).__init__()
1313
self.bg_encoder = models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)
1414
self.preprocess = torchvision.transforms.Compose([
15-
torchvision.transforms.Resize((256, 256)),
15+
torchvision.transforms.Resize((256, 256), antialias=True),
1616
])
1717
self.bg_encoder.conv1 = nn.Conv2d(6, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
1818
num_features = self.bg_encoder.fc.in_features

modules/keypoint_detector.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def __init__(self, num_tps, **kwargs):
1616
num_features = self.fg_encoder.fc.in_features
1717
self.fg_encoder.fc = nn.Linear(num_features, num_tps*5*2)
1818
self.preprocess = torchvision.transforms.Compose([
19-
torchvision.transforms.Resize((256, 256)),
19+
torchvision.transforms.Resize((256, 256), antialias=True),
2020
])
2121

2222

save_model_only.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636

3737

3838
bg_predictor = None
39-
if (config['model_params']['common_params']['bg']):
39+
if 'bg_predictor' in checkpoint:
4040
bg_predictor = BGMotionPredictor()
4141

4242
avd_network = None

train.py

Lines changed: 43 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
accelerator = Accelerator()
1717

18+
1819
def train(config, inpainting_network, kp_detector, bg_predictor, dense_motion_network, checkpoint, log_dir, dataset,
1920
optimizer_class=torch.optim.Adam
2021
):
@@ -44,23 +45,35 @@ def train(config, inpainting_network, kp_detector, bg_predictor, dense_motion_ne
4445
else:
4546
start_epoch = 0
4647

47-
48-
48+
freeze_kp_detector = train_params.get('freeze_kp_detector', False)
49+
freeze_bg_predictor = train_params.get('freeze_bg_predictor', False)
50+
if freeze_kp_detector:
51+
print('freeze kp detector')
52+
kp_detector.eval()
53+
for param in kp_detector.parameters():
54+
param.requires_grad = False
55+
if freeze_bg_predictor:
56+
print('freeze bg predictor')
57+
bg_predictor.eval()
58+
for param in bg_predictor.parameters():
59+
param.requires_grad = False
4960

5061
if 'num_repeats' in train_params or train_params['num_repeats'] != 1:
5162
dataset = DatasetRepeater(dataset, train_params['num_repeats'])
5263
dataloader = DataLoader(dataset, batch_size=train_params['batch_size'], shuffle=True,
5364
num_workers=train_params['dataloader_workers'], drop_last=True)
5465

5566
scheduler_optimizer = OneCycleLR(optimizer, max_lr=train_params['lr_generator'],
56-
total_steps=(len(dataset) // train_params['batch_size']) * train_params['num_epochs'],
57-
last_epoch=start_epoch-1)
67+
total_steps=(len(dataset) // train_params['batch_size']) * train_params[
68+
'num_epochs'],
69+
last_epoch=start_epoch - 1)
5870

5971
scheduler_bg_predictor = None
6072
if bg_predictor:
6173
scheduler_bg_predictor = OneCycleLR(optimizer_bg_predictor, max_lr=train_params['lr_generator'],
62-
total_steps=(len(dataset) // train_params['batch_size']) * train_params['num_epochs'],
63-
last_epoch=start_epoch-1)
74+
total_steps=(len(dataset) // train_params['batch_size']) * train_params[
75+
'num_epochs'],
76+
last_epoch=start_epoch - 1)
6477
bg_predictor, optimizer_bg_predictor = accelerator.prepare(bg_predictor, optimizer_bg_predictor)
6578

6679
generator_full = GeneratorFullModel(kp_detector, bg_predictor, dense_motion_network, inpainting_network,
@@ -75,16 +88,21 @@ def train(config, inpainting_network, kp_detector, bg_predictor, dense_motion_ne
7588
if train_params.get('visualize_model', False):
7689
# visualize graph
7790
sample = next(iter(dataloader))
78-
draw_graph(generator_full, input_data=[sample, 100], save_graph=True, directory=log_dir, graph_name='generator_full')
79-
draw_graph(kp_detector, input_data=[sample['driving']], save_graph=True, directory=log_dir, graph_name='kp_detector')
91+
draw_graph(generator_full, input_data=[sample, 100], save_graph=True, directory=log_dir,
92+
graph_name='generator_full')
93+
draw_graph(kp_detector, input_data=[sample['driving']], save_graph=True, directory=log_dir,
94+
graph_name='kp_detector')
8095
kp_driving = kp_detector(sample['driving'])
8196
kp_source = kp_detector(sample['source'])
8297
bg_param = bg_predictor(sample['source'], sample['driving'])
83-
dense_motion_param = {'source_image': sample['source'], 'kp_driving': kp_driving, 'kp_source': kp_source, 'bg_param': bg_param,
84-
'dropout_flag' : False, 'dropout_p' : 0.0}
98+
dense_motion_param = {'source_image': sample['source'], 'kp_driving': kp_driving, 'kp_source': kp_source,
99+
'bg_param': bg_param,
100+
'dropout_flag': False, 'dropout_p': 0.0}
85101
dense_motion = dense_motion_network(**dense_motion_param)
86-
draw_graph(dense_motion_network, input_data=dense_motion_param, save_graph=True, directory=log_dir, graph_name='dense_motion_network')
87-
draw_graph(inpainting_network, input_data=[sample['source'], dense_motion], save_graph=True, directory=log_dir, graph_name='inpainting_network')
102+
draw_graph(dense_motion_network, input_data=dense_motion_param, save_graph=True, directory=log_dir,
103+
graph_name='dense_motion_network')
104+
draw_graph(inpainting_network, input_data=[sample['source'], dense_motion], save_graph=True, directory=log_dir,
105+
graph_name='inpainting_network')
88106

89107
with Logger(log_dir=log_dir, visualizer_params=config['visualizer_params'],
90108
checkpoint_freq=train_params['checkpoint_freq'],
@@ -100,14 +118,18 @@ def train(config, inpainting_network, kp_detector, bg_predictor, dense_motion_ne
100118

101119
clip_grad_norm_(kp_detector.parameters(), max_norm=10, norm_type=math.inf)
102120
clip_grad_norm_(dense_motion_network.parameters(), max_norm=10, norm_type=math.inf)
103-
if bg_predictor and epoch >= bg_start:
121+
if bg_predictor and epoch >= bg_start and not freeze_bg_predictor:
104122
clip_grad_norm_(bg_predictor.parameters(), max_norm=10, norm_type=math.inf)
105123

106124
optimizer.step()
107-
optimizer.zero_grad()
108-
if bg_predictor and epoch >= bg_start:
125+
126+
if bg_predictor and epoch >= bg_start and not freeze_bg_predictor:
109127
optimizer_bg_predictor.step()
110128
optimizer_bg_predictor.zero_grad()
129+
scheduler_bg_predictor.step()
130+
131+
optimizer.zero_grad()
132+
scheduler_optimizer.step()
111133

112134
losses = {key: value.mean().detach().data.cpu().numpy() for key, value in losses_generator.items()}
113135
lrs = {
@@ -116,23 +138,15 @@ def train(config, inpainting_network, kp_detector, bg_predictor, dense_motion_ne
116138
}
117139
logger.log_iter(losses=losses, others=lrs)
118140

119-
scheduler_optimizer.step()
120-
if bg_predictor:
121-
scheduler_bg_predictor.step()
141+
122142

123143
model_save = {
124-
'inpainting_network': inpainting_network,
125-
'dense_motion_network': dense_motion_network,
126-
'kp_detector': kp_detector,
144+
'inpainting_network': accelerator.unwrap_model(inpainting_network),
145+
'dense_motion_network': accelerator.unwrap_model(dense_motion_network),
146+
'kp_detector': accelerator.unwrap_model(kp_detector),
127147
'optimizer': optimizer,
148+
'bg_predictor': accelerator.unwrap_model(bg_predictor) if bg_predictor else None,
149+
'optimizer_bg_predictor': optimizer_bg_predictor
128150
}
129-
if bg_predictor and epoch >= bg_start:
130-
model_save['bg_predictor'] = bg_predictor
131-
model_save['optimizer_bg_predictor'] = optimizer_bg_predictor
132-
133-
accelerator.save_state(log_dir)
134-
135151

136152
logger.log_epoch(epoch, model_save, inp=x, out=generated)
137-
138-

0 commit comments

Comments
 (0)