From 463084650abf79babd7594d39a7dfcb3cf06926c Mon Sep 17 00:00:00 2001 From: TGG Date: Tue, 4 Apr 2023 18:21:03 +0200 Subject: [PATCH 01/30] add video dataset downloader as submodule --- .gitmodules | 3 +++ video-preprocessing | 1 + 2 files changed, 4 insertions(+) create mode 100644 .gitmodules create mode 160000 video-preprocessing diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..a6546b3 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "video-preprocessing"] + path = video-preprocessing + url = https://github.com/AliaksandrSiarohin/video-preprocessing.git diff --git a/video-preprocessing b/video-preprocessing new file mode 160000 index 0000000..ac40aac --- /dev/null +++ b/video-preprocessing @@ -0,0 +1 @@ +Subproject commit ac40aac58657a3d8db85421cd4afcf465e86ead1 From 429a5d17c7d4d07c8e35992672efe951f2a4d0c5 Mon Sep 17 00:00:00 2001 From: TGG Date: Sat, 8 Jul 2023 13:57:01 +0200 Subject: [PATCH 02/30] change model definitions and training --- config/vox-256-finetune.yaml | 78 +++++++++++++++++++++++++++ config/vox-256.yaml | 4 ++ config/vox-512-finetune.yaml | 24 +++++---- config/vox-768-finetune.yaml | 22 ++++---- frames_dataset.py | 10 +++- logger.py | 6 ++- modules/bg_motion_predictor.py | 9 ++++ modules/keypoint_detector.py | 4 ++ modules/util.py | 12 ++--- run.py | 19 ++++++- save_model_only.py | 68 ++++++++++++++++++++++++ train.py | 97 ++++++++++++++++++++++------------ 12 files changed, 288 insertions(+), 65 deletions(-) create mode 100644 config/vox-256-finetune.yaml create mode 100644 save_model_only.py diff --git a/config/vox-256-finetune.yaml b/config/vox-256-finetune.yaml new file mode 100644 index 0000000..daaeec4 --- /dev/null +++ b/config/vox-256-finetune.yaml @@ -0,0 +1,78 @@ +dataset_params: + root_dir: ./video-preprocessing/vox2-768 + frame_shape: 256,256,3 + id_sampling: True + augmentation_params: + flip_param: + horizontal_flip: True + time_flip: True + jitter_param: + brightness: 0.1 + contrast: 0.1 + saturation: 0.1 + hue: 0.1 + + +model_params: + common_params: + num_tps: 10 + num_channels: 3 + bg: True + multi_mask: True + generator_params: + block_expansion: 64 + max_features: 512 + num_down_blocks: 3 + dense_motion_params: + block_expansion: 64 + max_features: 1024 + num_blocks: 5 + scale_factor: 0.25 + avd_network_params: + id_bottle_size: 128 + pose_bottle_size: 128 + + +train_params: + num_epochs: 40 + num_repeats: 10 + epoch_milestones: [15, 30] + lr_generator: 2.0e-4 + batch_size: 16 + scales: [1, 0.5, 0.25, 0.125] + dataloader_workers: 12 + checkpoint_freq: 50 + dropout_epoch: 2 + dropout_maxp: 0.3 + dropout_startp: 0.1 + dropout_inc_epoch: 10 + bg_start: 5 + transform_params: + sigma_affine: 0.05 + sigma_tps: 0.005 + points_tps: 5 + loss_weights: + perceptual: [10, 10, 10, 10, 10] + equivariance_value: 10 + warp_loss: 10 + bg: 10 + optimizer: 'adamw' + optimizer_params: + betas: [ 0.9, 0.999 ] + weight_decay: 0.1 + +train_avd_params: + num_epochs: 100 + num_repeats: 1 + batch_size: 8 + dataloader_workers: 6 + checkpoint_freq: 1 + epoch_milestones: [10, 20] + lr: 1.0e-3 + lambda_shift: 1 + random_scale: 0.25 + +visualizer_params: + kp_size: 5 + draw_border: True + colormap: 'gist_rainbow' \ No newline at end of file diff --git a/config/vox-256.yaml b/config/vox-256.yaml index 7b28e83..9d555a6 100644 --- a/config/vox-256.yaml +++ b/config/vox-256.yaml @@ -56,6 +56,10 @@ train_params: equivariance_value: 10 warp_loss: 10 bg: 10 + optimizer: 'adam' + optimizer_params: + betas: [ 0.5, 0.999 ] + weight_decay: 1e-4 
train_avd_params: num_epochs: 100 diff --git a/config/vox-512-finetune.yaml b/config/vox-512-finetune.yaml index 408d687..5deeb04 100644 --- a/config/vox-512-finetune.yaml +++ b/config/vox-512-finetune.yaml @@ -1,7 +1,7 @@ # Use this file to finetune from a pretrained 256x256 model dataset_params: - root_dir: vox - frame_shape: null + root_dir: ./video-preprocessing/vox2-768 + frame_shape: 512,512,3 id_sampling: True augmentation_params: flip_param: @@ -35,20 +35,20 @@ model_params: train_params: - num_epochs: 100 - num_repeats: 10 - epoch_milestones: [70, 90] + num_epochs: 30 + num_repeats: 4 + epoch_milestones: [20] # Higher LR seems to bring problems when finetuning lr_generator: 2.0e-5 batch_size: 4 scales: [1, 0.5, 0.25, 0.125] dataloader_workers: 6 - checkpoint_freq: 2 - dropout_epoch: 0 + checkpoint_freq: 5 + dropout_epoch: 2 dropout_maxp: 0.3 dropout_startp: 0.1 - dropout_inc_epoch: 10 - bg_start: 0 + dropout_inc_epoch: 1 + bg_start: 5 transform_params: sigma_affine: 0.05 sigma_tps: 0.005 @@ -58,13 +58,17 @@ train_params: equivariance_value: 10 warp_loss: 10 bg: 10 + optimizer: 'adamw' + optimizer_params: + betas: [0.9, 0.999] + weight_decay: 0.1 train_avd_params: num_epochs: 200 num_repeats: 1 batch_size: 4 dataloader_workers: 6 - checkpoint_freq: 2 + checkpoint_freq: 10 epoch_milestones: [10, 20] lr: 1.0e-3 lambda_shift: 1 diff --git a/config/vox-768-finetune.yaml b/config/vox-768-finetune.yaml index 2f8ccae..311617b 100644 --- a/config/vox-768-finetune.yaml +++ b/config/vox-768-finetune.yaml @@ -1,6 +1,6 @@ # Use this file to finetune from a pretrained 256x256 model dataset_params: - root_dir: vox_768 + root_dir: ./video-preprocessing/vox2-768 frame_shape: null id_sampling: True augmentation_params: @@ -35,20 +35,20 @@ model_params: train_params: - num_epochs: 100 - num_repeats: 1 - epoch_milestones: [70, 90] + visualize_model: False + num_epochs: 40 + num_repeats: 4 # Higher LR seems to bring problems when finetuning lr_generator: 2.0e-5 - batch_size: 1 + batch_size: 2 scales: [1, 0.5, 0.25, 0.125] - dataloader_workers: 6 - checkpoint_freq: 1 + dataloader_workers: 8 + checkpoint_freq: 2 dropout_epoch: 0 dropout_maxp: 0.3 dropout_startp: 0.1 dropout_inc_epoch: 10 - bg_start: 0 + bg_start: 5 transform_params: sigma_affine: 0.05 sigma_tps: 0.005 @@ -58,6 +58,10 @@ train_params: equivariance_value: 10 warp_loss: 10 bg: 10 + optimizer: 'adamw' + optimizer_params: + betas: [ 0.9, 0.999 ] + weight_decay: 0.1 train_avd_params: num_epochs: 200 @@ -73,4 +77,4 @@ train_avd_params: visualizer_params: kp_size: 5 draw_border: True - colormap: 'gist_rainbow' \ No newline at end of file + colormap: 'gist_rainbow' diff --git a/frames_dataset.py b/frames_dataset.py index 69fc936..afc5a9c 100644 --- a/frames_dataset.py +++ b/frames_dataset.py @@ -67,6 +67,8 @@ def __init__(self, root_dir, frame_shape=(256, 256, 3), id_sampling=False, is_tr random_seed=0, pairs_list=None, augmentation_params=None): self.root_dir = root_dir self.videos = os.listdir(root_dir) + if type(frame_shape) == str: + frame_shape = tuple(map(int, frame_shape.split(','))) self.frame_shape = frame_shape print(self.frame_shape) self.pairs_list = pairs_list @@ -115,7 +117,13 @@ def __getitem__(self, idx): frames = os.listdir(path) num_frames = len(frames) - frame_idx = np.sort(np.random.choice(num_frames, replace=True, size=2)) + # use more frames that are different from each other to speed up training + min_frames_apart = num_frames // 4 + first_frame_idx = np.random.choice(num_frames - min_frames_apart) + second_frame_idx = 
np.random.choice(range(first_frame_idx + min_frames_apart, num_frames)) + frame_idx = np.array([first_frame_idx, second_frame_idx]) + np.random.shuffle(frame_idx) + #frame_idx = np.sort(np.random.choice(num_frames, replace=True, size=2)) if self.frame_shape is not None: resize_fn = partial(resize, output_shape=self.frame_shape) diff --git a/logger.py b/logger.py index e607253..b337975 100644 --- a/logger.py +++ b/logger.py @@ -17,12 +17,13 @@ class Logger: def __init__(self, log_dir, checkpoint_freq=50, visualizer_params=None, zfill_num=8, log_file_name='log.txt', models=()): - self.models = None + self.models = models self.loss_list = [] self.cpk_dir = log_dir self.visualizations_dir = os.path.join(log_dir, 'train-vis') if not os.path.exists(self.visualizations_dir): os.makedirs(self.visualizations_dir) + print("Visualizations will be saved in %s" % self.visualizations_dir) self.log_file = open(os.path.join(log_dir, log_file_name), 'a') self.zfill_num = zfill_num self.visualizer = Visualizer(**visualizer_params) @@ -46,9 +47,10 @@ def log_scores(self, loss_names): def visualize_rec(self, inp, out): image = self.visualizer.visualize(inp['driving'], inp['source'], out) + wandb.log({"image": [wandb.Image(image)]}) imageio.imsave(os.path.join(self.visualizations_dir, "%s-rec.png" % str(self.epoch).zfill(self.zfill_num)), image) - wandb.log({"image": [wandb.Image(image)]}) + def save_cpk(self, emergent=False): cpk = {k: v.state_dict() for k, v in self.models.items()} diff --git a/modules/bg_motion_predictor.py b/modules/bg_motion_predictor.py index 446dc38..3d18299 100644 --- a/modules/bg_motion_predictor.py +++ b/modules/bg_motion_predictor.py @@ -11,6 +11,9 @@ class BGMotionPredictor(nn.Module): def __init__(self): super(BGMotionPredictor, self).__init__() self.bg_encoder = models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT) + self.preprocess = torchvision.transforms.Compose([ + torchvision.transforms.Resize((256, 256)), + ]) self.bg_encoder.conv1 = nn.Conv2d(6, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) num_features = self.bg_encoder.fc.in_features self.bg_encoder.fc = nn.Linear(num_features, 6) @@ -18,8 +21,14 @@ def __init__(self): self.bg_encoder.fc.bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float)) def forward(self, source_image, driving_image): + + bs = source_image.shape[0] out = torch.eye(3).unsqueeze(0).repeat(bs, 1, 1).type(source_image.type()) + + source_image = self.preprocess(source_image) + driving_image = self.preprocess(driving_image) + prediction = self.bg_encoder(torch.cat([source_image, driving_image], dim=1)) out[:, :2, :] = prediction.view(bs, 2, 3) return out diff --git a/modules/keypoint_detector.py b/modules/keypoint_detector.py index 25b8193..4b8c976 100644 --- a/modules/keypoint_detector.py +++ b/modules/keypoint_detector.py @@ -15,9 +15,13 @@ def __init__(self, num_tps, **kwargs): self.fg_encoder = models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT) num_features = self.fg_encoder.fc.in_features self.fg_encoder.fc = nn.Linear(num_features, num_tps*5*2) + self.preprocess = torchvision.transforms.Compose([ + torchvision.transforms.Resize((256, 256)), + ]) def forward(self, image): + image = self.preprocess(image) fg_kp = self.fg_encoder(image) bs, _, = fg_kp.shape diff --git a/modules/util.py b/modules/util.py index 0a86991..2ccc367 100644 --- a/modules/util.py +++ b/modules/util.py @@ -150,10 +150,10 @@ def __init__(self, in_features, kernel_size, padding): def forward(self, x): out = 
self.norm1(x) - out = F.relu(out) + out = F.mish(out) out = self.conv1(out) out = self.norm2(out) - out = F.relu(out) + out = F.mish(out) out = self.conv2(out) out += x return out @@ -172,10 +172,10 @@ def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1 self.norm = nn.InstanceNorm2d(out_features, affine=True) def forward(self, x): - out = F.interpolate(x, scale_factor=2) + out = F.interpolate(x, scale_factor=2, mode='nearest') out = self.conv(out) out = self.norm(out) - out = F.relu(out) + out = F.mish(out) return out @@ -194,7 +194,7 @@ def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1 def forward(self, x): out = self.conv(x) out = self.norm(out) - out = F.relu(out) + out = F.mish(out) out = self.pool(out) return out @@ -213,7 +213,7 @@ def __init__(self, in_features, out_features, groups=1, kernel_size=3, padding=1 def forward(self, x): out = self.conv(x) out = self.norm(out) - out = F.relu(out) + out = F.mish(out) return out diff --git a/run.py b/run.py index 0e4fae9..f182542 100644 --- a/run.py +++ b/run.py @@ -18,6 +18,7 @@ from train_avd import train_avd from reconstruction import reconstruction import os +from torchinfo import summary import bitsandbytes as bnb optimizer_choices = { @@ -37,7 +38,7 @@ parser.add_argument("--mode", default="train", choices=["train", "reconstruction", "train_avd"]) parser.add_argument("--log_dir", default='log', help="path to log into") parser.add_argument("--checkpoint", default=None, help="path to checkpoint to restore") - parser.add_argument("--optimizer_class", default="adam", choices=optimizer_choices.keys()) + parser.add_argument("--detect_anomaly", action="store_true", help="detect anomaly in autograd") opt = parser.parse_args() @@ -50,6 +51,9 @@ log_dir = os.path.join(opt.log_dir, os.path.basename(opt.config).split('.')[0]) log_dir += ' ' + strftime("%d_%m_%y_%H.%M.%S", gmtime()) + if opt.detect_anomaly: + torch.autograd.set_detect_anomaly(True) + inpainting = InpaintingNetwork(**config['model_params']['generator_params'], **config['model_params']['common_params']) @@ -76,7 +80,17 @@ if not os.path.exists(os.path.join(log_dir, os.path.basename(opt.config))): copy(opt.config, log_dir) - optimizer_class = optimizer_choices[opt.optimizer_class] + optimizer_class = optimizer_choices[config['train_params']['optimizer']] + + print("Inpainting Network:") + summary(inpainting) + print("Keypoint Detector:") + summary(kp_detector) + print("Dense Motion Network:") + summary(dense_motion_network) + if bg_predictor is not None: + print("Background Predictor:") + summary(bg_predictor) if opt.mode == 'train': print("Training...") @@ -90,3 +104,4 @@ print("Reconstruction...") #TODO: update to accelerate reconstruction(config, inpainting, kp_detector, bg_predictor, dense_motion_network, opt.checkpoint, log_dir, dataset) + diff --git a/save_model_only.py b/save_model_only.py new file mode 100644 index 0000000..af9bc2d --- /dev/null +++ b/save_model_only.py @@ -0,0 +1,68 @@ +import os +from argparse import ArgumentParser + +import yaml + +from modules.inpainting_network import InpaintingNetwork +from modules.keypoint_detector import KPDetector +from modules.bg_motion_predictor import BGMotionPredictor +from modules.dense_motion import DenseMotionNetwork +from modules.avd_network import AVDNetwork + +import torch + +if __name__ == '__main__': + parser = ArgumentParser() + parser.add_argument('config', help='path to config') + parser.add_argument('--checkpoint', '-c', help='path to checkpoint to restore') 
+ parser.add_argument('--target-checkpoint', '-t', default=None, help='path to checkpoint to save') + + opt = parser.parse_args() + + with open(opt.config) as f: + config = yaml.load(f) + + checkpoint = torch.load(opt.checkpoint) + print(checkpoint.keys()) + + + inpainting = InpaintingNetwork(**config['model_params']['generator_params'], + **config['model_params']['common_params']) + + + kp_detector = KPDetector(**config['model_params']['common_params']) + dense_motion_network = DenseMotionNetwork(**config['model_params']['common_params'], + **config['model_params']['dense_motion_params']) + + + bg_predictor = None + if (config['model_params']['common_params']['bg']): + bg_predictor = BGMotionPredictor() + + avd_network = None + if 'avd_network' in checkpoint: + avd_network = AVDNetwork(num_tps=config['model_params']['common_params']['num_tps'], + **config['model_params']['avd_network_params']) + + target_checkpoint = opt.target_checkpoint if opt.target_checkpoint is not None else os.path.join(opt.checkpoint.split('.')[0] + '_model_only.pth.tar') + + + inpainting.load_state_dict(checkpoint['inpainting_network']) + kp_detector.load_state_dict(checkpoint['kp_detector']) + dense_motion_network.load_state_dict(checkpoint['dense_motion_network']) + if bg_predictor is not None: + bg_predictor.load_state_dict(checkpoint['bg_predictor']) + if avd_network is not None: + avd_network.load_state_dict(checkpoint['avd_network']) + + save_dict = { + 'inpainting_network': inpainting.state_dict(), + 'kp_detector': kp_detector.state_dict(), + 'dense_motion_network': dense_motion_network.state_dict() + } + if bg_predictor is not None: + save_dict['bg_predictor'] = bg_predictor.state_dict() + if avd_network is not None: + save_dict['avd_network'] = avd_network.state_dict() + + torch.save(save_dict, target_checkpoint) \ No newline at end of file diff --git a/train.py b/train.py index aa2fd63..b76b780 100644 --- a/train.py +++ b/train.py @@ -3,12 +3,15 @@ from torch.utils.data import DataLoader from logger import Logger from modules.model import GeneratorFullModel -from torch.optim.lr_scheduler import MultiStepLR +from torch.optim.lr_scheduler import OneCycleLR from torch.nn.utils import clip_grad_norm_ from frames_dataset import DatasetRepeater from tqdm import tqdm import math from accelerate import Accelerator +from torchview import draw_graph + +torch.backends.cudnn.benchmark = True accelerator = Accelerator() @@ -16,53 +19,74 @@ def train(config, inpainting_network, kp_detector, bg_predictor, dense_motion_ne optimizer_class=torch.optim.Adam ): train_params = config['train_params'] + optimizer_params = config['train_params'].get('optimizer_params', {}) optimizer = optimizer_class( [{'params': list(inpainting_network.parameters()) + list(dense_motion_network.parameters()) + list(kp_detector.parameters()), 'initial_lr': train_params['lr_generator']}], - lr=train_params['lr_generator'], betas=(0.5, 0.999), weight_decay = 1e-4) - + lr=train_params['lr_generator'], **optimizer_params) + optimizer_bg_predictor = None if bg_predictor: optimizer_bg_predictor = optimizer_class( - [{'params':bg_predictor.parameters(),'initial_lr': train_params['lr_generator']}], - lr=train_params['lr_generator'], betas=(0.5, 0.999), weight_decay = 1e-4) + [{'params': bg_predictor.parameters(), 'initial_lr': train_params['lr_generator']}], + lr=train_params['lr_generator'], betas=(0.5, 0.999), weight_decay=1e-4) if checkpoint is not None: start_epoch = Logger.load_cpk( - checkpoint, inpainting_network = inpainting_network, 
dense_motion_network = dense_motion_network, - kp_detector = kp_detector, bg_predictor = bg_predictor, - optimizer = optimizer, optimizer_bg_predictor = optimizer_bg_predictor) + checkpoint, inpainting_network=inpainting_network, dense_motion_network=dense_motion_network, + kp_detector=kp_detector, bg_predictor=bg_predictor, + optimizer=optimizer, optimizer_bg_predictor=optimizer_bg_predictor) print('load success:', start_epoch) start_epoch += 1 else: start_epoch = 0 - scheduler_optimizer = MultiStepLR(optimizer, train_params['epoch_milestones'], gamma=0.1, - last_epoch=start_epoch - 1) - scheduler_bg_predictor = None - if bg_predictor: - scheduler_bg_predictor = MultiStepLR(optimizer_bg_predictor, train_params['epoch_milestones'], - gamma=0.1, last_epoch=start_epoch - 1) - bg_predictor, optimizer_bg_predictor = accelerator.prepare(bg_predictor, optimizer_bg_predictor) + if 'num_repeats' in train_params or train_params['num_repeats'] != 1: dataset = DatasetRepeater(dataset, train_params['num_repeats']) - dataloader = DataLoader(dataset, batch_size=train_params['batch_size'], shuffle=True, + dataloader = DataLoader(dataset, batch_size=train_params['batch_size'], shuffle=True, num_workers=train_params['dataloader_workers'], drop_last=True) - generator_full = GeneratorFullModel(kp_detector, bg_predictor, dense_motion_network, inpainting_network, train_params) - - bg_start = train_params['bg_start'] + scheduler_optimizer = OneCycleLR(optimizer, max_lr=train_params['lr_generator'], + total_steps=(len(dataset) // train_params['batch_size']) * train_params['num_epochs'], + last_epoch=start_epoch-1) - inpainting_network, kp_detector, dense_motion_network, optimizer, scheduler_optimizer, dataloader, generator_full = accelerator.prepare( - inpainting_network, kp_detector, dense_motion_network, optimizer, scheduler_optimizer, dataloader, generator_full) + scheduler_bg_predictor = None + if bg_predictor: + scheduler_bg_predictor = OneCycleLR(optimizer_bg_predictor, max_lr=train_params['lr_generator'], + total_steps=(len(dataset) // train_params['batch_size']) * train_params['num_epochs'], + last_epoch=start_epoch-1) + bg_predictor, optimizer_bg_predictor = accelerator.prepare(bg_predictor, optimizer_bg_predictor) + + generator_full = GeneratorFullModel(kp_detector, bg_predictor, dense_motion_network, inpainting_network, + train_params) + bg_start = train_params['bg_start'] - with Logger(log_dir=log_dir, visualizer_params=config['visualizer_params'], + inpainting_network, kp_detector, dense_motion_network, optimizer, scheduler_optimizer, dataloader, generator_full = accelerator.prepare( + inpainting_network, kp_detector, dense_motion_network, optimizer, scheduler_optimizer, dataloader, + generator_full) + + if train_params.get('visualize_model', False): + # visualize graph + sample = next(iter(dataloader)) + draw_graph(generator_full, input_data=[sample, 100], save_graph=True, directory=log_dir, graph_name='generator_full') + draw_graph(kp_detector, input_data=[sample['driving']], save_graph=True, directory=log_dir, graph_name='kp_detector') + kp_driving = kp_detector(sample['driving']) + kp_source = kp_detector(sample['source']) + bg_param = bg_predictor(sample['source'], sample['driving']) + dense_motion_param = {'source_image': sample['source'], 'kp_driving': kp_driving, 'kp_source': kp_source, 'bg_param': bg_param, + 'dropout_flag' : False, 'dropout_p' : 0.0} + dense_motion = dense_motion_network(**dense_motion_param) + draw_graph(dense_motion_network, input_data=dense_motion_param, 
save_graph=True, directory=log_dir, graph_name='dense_motion_network') + draw_graph(inpainting_network, input_data=[sample['source'], dense_motion], save_graph=True, directory=log_dir, graph_name='inpainting_network') + + with Logger(log_dir=log_dir, visualizer_params=config['visualizer_params'], checkpoint_freq=train_params['checkpoint_freq'], models=[inpainting_network, dense_motion_network, kp_detector] ) as logger: @@ -74,17 +98,17 @@ def train(config, inpainting_network, kp_detector, bg_predictor, dense_motion_ne accelerator.backward(loss) - clip_grad_norm_(kp_detector.parameters(), max_norm=10, norm_type = math.inf) - clip_grad_norm_(dense_motion_network.parameters(), max_norm=10, norm_type = math.inf) - if bg_predictor and epoch>=bg_start: - clip_grad_norm_(bg_predictor.parameters(), max_norm=10, norm_type = math.inf) - + clip_grad_norm_(kp_detector.parameters(), max_norm=10, norm_type=math.inf) + clip_grad_norm_(dense_motion_network.parameters(), max_norm=10, norm_type=math.inf) + if bg_predictor and epoch >= bg_start: + clip_grad_norm_(bg_predictor.parameters(), max_norm=10, norm_type=math.inf) + optimizer.step() optimizer.zero_grad() - if bg_predictor and epoch>=bg_start: + if bg_predictor and epoch >= bg_start: optimizer_bg_predictor.step() optimizer_bg_predictor.zero_grad() - + losses = {key: value.mean().detach().data.cpu().numpy() for key, value in losses_generator.items()} lrs = { 'lr_generator': scheduler_optimizer.get_last_lr()[0], @@ -92,20 +116,23 @@ def train(config, inpainting_network, kp_detector, bg_predictor, dense_motion_ne } logger.log_iter(losses=losses, others=lrs) - scheduler_optimizer.step() - if bg_predictor: - scheduler_bg_predictor.step() - + scheduler_optimizer.step() + if bg_predictor: + scheduler_bg_predictor.step() + model_save = { 'inpainting_network': inpainting_network, 'dense_motion_network': dense_motion_network, 'kp_detector': kp_detector, 'optimizer': optimizer, } - if bg_predictor and epoch>=bg_start: + if bg_predictor and epoch >= bg_start: model_save['bg_predictor'] = bg_predictor model_save['optimizer_bg_predictor'] = optimizer_bg_predictor - + + accelerator.save_state(log_dir) + + logger.log_epoch(epoch, model_save, inp=x, out=generated) From f6deb7b67fb869a161a917e6237c26a0388312ae Mon Sep 17 00:00:00 2001 From: TGG Date: Wed, 12 Jul 2023 09:56:16 +0200 Subject: [PATCH 03/30] change model definitions and training --- config/vox-1024-finetune.yaml | 82 ++++++++++++++++++++++++++++++++++ config/vox-256-finetune.yaml | 14 +++--- config/vox-512-finetune.yaml | 11 ++--- config/vox-768-finetune.yaml | 14 +++--- modules/bg_motion_predictor.py | 2 +- modules/keypoint_detector.py | 2 +- save_model_only.py | 2 +- train.py | 72 +++++++++++++++++------------ 8 files changed, 150 insertions(+), 49 deletions(-) create mode 100644 config/vox-1024-finetune.yaml diff --git a/config/vox-1024-finetune.yaml b/config/vox-1024-finetune.yaml new file mode 100644 index 0000000..a2a955d --- /dev/null +++ b/config/vox-1024-finetune.yaml @@ -0,0 +1,82 @@ +# Use this file to finetune from a pretrained 256x256 model +dataset_params: + root_dir: ./video-preprocessing/vox2-768 + frame_shape: 1024,1024,3 + id_sampling: True + augmentation_params: + flip_param: + horizontal_flip: True + time_flip: True + jitter_param: + brightness: 0.1 + contrast: 0.1 + saturation: 0.1 + hue: 0.1 + + +model_params: + common_params: + num_tps: 10 + num_channels: 3 + bg: True + multi_mask: True + generator_params: + block_expansion: 64 + max_features: 512 + num_down_blocks: 3 + 
dense_motion_params: + block_expansion: 64 + max_features: 1024 + num_blocks: 5 + scale_factor: 0.25 + avd_network_params: + id_bottle_size: 128 + pose_bottle_size: 128 + + +train_params: + num_epochs: 5 + num_repeats: 4 + # Higher LR seems to bring problems when finetuning + lr_generator: 2.0e-5 + batch_size: 1 + scales: [1, 0.5, 0.25, 0.125, 0.0625, 0.03125] + dataloader_workers: 6 + checkpoint_freq: 5 + dropout_epoch: 2 + dropout_maxp: 0.3 + dropout_startp: 0.1 + dropout_inc_epoch: 1 + bg_start: 81 + freeze_kp_detector: True + freeze_bg_predictor: True + transform_params: + sigma_affine: 0.05 + sigma_tps: 0.005 + points_tps: 5 + loss_weights: + perceptual: [10, 10, 10, 10, 10] + equivariance_value: 10 + warp_loss: 10 + bg: 10 + optimizer: 'adamw' + optimizer_params: + betas: [ 0.9, 0.999 ] + weight_decay: 0.1 + + +train_avd_params: + num_epochs: 200 + num_repeats: 1 + batch_size: 1 + dataloader_workers: 6 + checkpoint_freq: 1 + epoch_milestones: [140, 180] + lr: 1.0e-3 + lambda_shift: 1 + random_scale: 0.25 + +visualizer_params: + kp_size: 5 + draw_border: True + colormap: 'gist_rainbow' diff --git a/config/vox-256-finetune.yaml b/config/vox-256-finetune.yaml index daaeec4..56833c8 100644 --- a/config/vox-256-finetune.yaml +++ b/config/vox-256-finetune.yaml @@ -34,19 +34,20 @@ model_params: train_params: - num_epochs: 40 + num_epochs: 5 num_repeats: 10 - epoch_milestones: [15, 30] - lr_generator: 2.0e-4 + lr_generator: 2.0e-5 batch_size: 16 scales: [1, 0.5, 0.25, 0.125] dataloader_workers: 12 - checkpoint_freq: 50 - dropout_epoch: 2 + checkpoint_freq: 10 + dropout_epoch: 0 dropout_maxp: 0.3 dropout_startp: 0.1 dropout_inc_epoch: 10 - bg_start: 5 + bg_start: 6 + freeze_kp_detector: False + freeze_bg_predictor: True transform_params: sigma_affine: 0.05 sigma_tps: 0.005 @@ -61,6 +62,7 @@ train_params: betas: [ 0.9, 0.999 ] weight_decay: 0.1 + train_avd_params: num_epochs: 100 num_repeats: 1 diff --git a/config/vox-512-finetune.yaml b/config/vox-512-finetune.yaml index 5deeb04..a84e6a9 100644 --- a/config/vox-512-finetune.yaml +++ b/config/vox-512-finetune.yaml @@ -35,20 +35,21 @@ model_params: train_params: - num_epochs: 30 + num_epochs: 40 num_repeats: 4 - epoch_milestones: [20] # Higher LR seems to bring problems when finetuning - lr_generator: 2.0e-5 + lr_generator: 2.0e-4 batch_size: 4 - scales: [1, 0.5, 0.25, 0.125] + scales: [1, 0.5, 0.25, 0.125, 0.0625] dataloader_workers: 6 checkpoint_freq: 5 dropout_epoch: 2 dropout_maxp: 0.3 dropout_startp: 0.1 dropout_inc_epoch: 1 - bg_start: 5 + bg_start: 41 + freeze_kp_detector: True + freeze_bg_predictor: True transform_params: sigma_affine: 0.05 sigma_tps: 0.005 diff --git a/config/vox-768-finetune.yaml b/config/vox-768-finetune.yaml index 311617b..6ac37bf 100644 --- a/config/vox-768-finetune.yaml +++ b/config/vox-768-finetune.yaml @@ -1,7 +1,7 @@ # Use this file to finetune from a pretrained 256x256 model dataset_params: root_dir: ./video-preprocessing/vox2-768 - frame_shape: null + frame_shape: 768,768,3 id_sampling: True augmentation_params: flip_param: @@ -36,19 +36,21 @@ model_params: train_params: visualize_model: False - num_epochs: 40 - num_repeats: 4 + num_epochs: 80 + num_repeats: 10 # Higher LR seems to bring problems when finetuning - lr_generator: 2.0e-5 + lr_generator: 3.0e-5 batch_size: 2 - scales: [1, 0.5, 0.25, 0.125] + scales: [1, 0.5, 0.25, 0.125, 0.0625] dataloader_workers: 8 checkpoint_freq: 2 dropout_epoch: 0 dropout_maxp: 0.3 dropout_startp: 0.1 dropout_inc_epoch: 10 - bg_start: 5 + bg_start: 81 + freeze_kp_detector: 
True + freeze_bg_predictor: True transform_params: sigma_affine: 0.05 sigma_tps: 0.005 diff --git a/modules/bg_motion_predictor.py b/modules/bg_motion_predictor.py index 3d18299..a0a72be 100644 --- a/modules/bg_motion_predictor.py +++ b/modules/bg_motion_predictor.py @@ -12,7 +12,7 @@ def __init__(self): super(BGMotionPredictor, self).__init__() self.bg_encoder = models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT) self.preprocess = torchvision.transforms.Compose([ - torchvision.transforms.Resize((256, 256)), + torchvision.transforms.Resize((256, 256), antialias=True), ]) self.bg_encoder.conv1 = nn.Conv2d(6, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) num_features = self.bg_encoder.fc.in_features diff --git a/modules/keypoint_detector.py b/modules/keypoint_detector.py index 4b8c976..327cc0a 100644 --- a/modules/keypoint_detector.py +++ b/modules/keypoint_detector.py @@ -16,7 +16,7 @@ def __init__(self, num_tps, **kwargs): num_features = self.fg_encoder.fc.in_features self.fg_encoder.fc = nn.Linear(num_features, num_tps*5*2) self.preprocess = torchvision.transforms.Compose([ - torchvision.transforms.Resize((256, 256)), + torchvision.transforms.Resize((256, 256), antialias=True), ]) diff --git a/save_model_only.py b/save_model_only.py index af9bc2d..dbf2f23 100644 --- a/save_model_only.py +++ b/save_model_only.py @@ -36,7 +36,7 @@ bg_predictor = None - if (config['model_params']['common_params']['bg']): + if 'bg_predictor' in checkpoint: bg_predictor = BGMotionPredictor() avd_network = None diff --git a/train.py b/train.py index b76b780..dfa2ff7 100644 --- a/train.py +++ b/train.py @@ -15,6 +15,7 @@ accelerator = Accelerator() + def train(config, inpainting_network, kp_detector, bg_predictor, dense_motion_network, checkpoint, log_dir, dataset, optimizer_class=torch.optim.Adam ): @@ -44,8 +45,18 @@ def train(config, inpainting_network, kp_detector, bg_predictor, dense_motion_ne else: start_epoch = 0 - - + freeze_kp_detector = train_params.get('freeze_kp_detector', False) + freeze_bg_predictor = train_params.get('freeze_bg_predictor', False) + if freeze_kp_detector: + print('freeze kp detector') + kp_detector.eval() + for param in kp_detector.parameters(): + param.requires_grad = False + if freeze_bg_predictor: + print('freeze bg predictor') + bg_predictor.eval() + for param in bg_predictor.parameters(): + param.requires_grad = False if 'num_repeats' in train_params or train_params['num_repeats'] != 1: dataset = DatasetRepeater(dataset, train_params['num_repeats']) @@ -53,14 +64,16 @@ def train(config, inpainting_network, kp_detector, bg_predictor, dense_motion_ne num_workers=train_params['dataloader_workers'], drop_last=True) scheduler_optimizer = OneCycleLR(optimizer, max_lr=train_params['lr_generator'], - total_steps=(len(dataset) // train_params['batch_size']) * train_params['num_epochs'], - last_epoch=start_epoch-1) + total_steps=(len(dataset) // train_params['batch_size']) * train_params[ + 'num_epochs'], + last_epoch=start_epoch - 1) scheduler_bg_predictor = None if bg_predictor: scheduler_bg_predictor = OneCycleLR(optimizer_bg_predictor, max_lr=train_params['lr_generator'], - total_steps=(len(dataset) // train_params['batch_size']) * train_params['num_epochs'], - last_epoch=start_epoch-1) + total_steps=(len(dataset) // train_params['batch_size']) * train_params[ + 'num_epochs'], + last_epoch=start_epoch - 1) bg_predictor, optimizer_bg_predictor = accelerator.prepare(bg_predictor, optimizer_bg_predictor) generator_full = 
GeneratorFullModel(kp_detector, bg_predictor, dense_motion_network, inpainting_network, @@ -75,16 +88,21 @@ def train(config, inpainting_network, kp_detector, bg_predictor, dense_motion_ne if train_params.get('visualize_model', False): # visualize graph sample = next(iter(dataloader)) - draw_graph(generator_full, input_data=[sample, 100], save_graph=True, directory=log_dir, graph_name='generator_full') - draw_graph(kp_detector, input_data=[sample['driving']], save_graph=True, directory=log_dir, graph_name='kp_detector') + draw_graph(generator_full, input_data=[sample, 100], save_graph=True, directory=log_dir, + graph_name='generator_full') + draw_graph(kp_detector, input_data=[sample['driving']], save_graph=True, directory=log_dir, + graph_name='kp_detector') kp_driving = kp_detector(sample['driving']) kp_source = kp_detector(sample['source']) bg_param = bg_predictor(sample['source'], sample['driving']) - dense_motion_param = {'source_image': sample['source'], 'kp_driving': kp_driving, 'kp_source': kp_source, 'bg_param': bg_param, - 'dropout_flag' : False, 'dropout_p' : 0.0} + dense_motion_param = {'source_image': sample['source'], 'kp_driving': kp_driving, 'kp_source': kp_source, + 'bg_param': bg_param, + 'dropout_flag': False, 'dropout_p': 0.0} dense_motion = dense_motion_network(**dense_motion_param) - draw_graph(dense_motion_network, input_data=dense_motion_param, save_graph=True, directory=log_dir, graph_name='dense_motion_network') - draw_graph(inpainting_network, input_data=[sample['source'], dense_motion], save_graph=True, directory=log_dir, graph_name='inpainting_network') + draw_graph(dense_motion_network, input_data=dense_motion_param, save_graph=True, directory=log_dir, + graph_name='dense_motion_network') + draw_graph(inpainting_network, input_data=[sample['source'], dense_motion], save_graph=True, directory=log_dir, + graph_name='inpainting_network') with Logger(log_dir=log_dir, visualizer_params=config['visualizer_params'], checkpoint_freq=train_params['checkpoint_freq'], @@ -100,14 +118,18 @@ def train(config, inpainting_network, kp_detector, bg_predictor, dense_motion_ne clip_grad_norm_(kp_detector.parameters(), max_norm=10, norm_type=math.inf) clip_grad_norm_(dense_motion_network.parameters(), max_norm=10, norm_type=math.inf) - if bg_predictor and epoch >= bg_start: + if bg_predictor and epoch >= bg_start and not freeze_bg_predictor: clip_grad_norm_(bg_predictor.parameters(), max_norm=10, norm_type=math.inf) optimizer.step() - optimizer.zero_grad() - if bg_predictor and epoch >= bg_start: + + if bg_predictor and epoch >= bg_start and not freeze_bg_predictor: optimizer_bg_predictor.step() optimizer_bg_predictor.zero_grad() + scheduler_bg_predictor.step() + + optimizer.zero_grad() + scheduler_optimizer.step() losses = {key: value.mean().detach().data.cpu().numpy() for key, value in losses_generator.items()} lrs = { @@ -116,23 +138,15 @@ def train(config, inpainting_network, kp_detector, bg_predictor, dense_motion_ne } logger.log_iter(losses=losses, others=lrs) - scheduler_optimizer.step() - if bg_predictor: - scheduler_bg_predictor.step() + model_save = { - 'inpainting_network': inpainting_network, - 'dense_motion_network': dense_motion_network, - 'kp_detector': kp_detector, + 'inpainting_network': accelerator.unwrap_model(inpainting_network), + 'dense_motion_network': accelerator.unwrap_model(dense_motion_network), + 'kp_detector': accelerator.unwrap_model(kp_detector), 'optimizer': optimizer, + 'bg_predictor': accelerator.unwrap_model(bg_predictor) if bg_predictor else 
None, + 'optimizer_bg_predictor': optimizer_bg_predictor } - if bg_predictor and epoch >= bg_start: - model_save['bg_predictor'] = bg_predictor - model_save['optimizer_bg_predictor'] = optimizer_bg_predictor - - accelerator.save_state(log_dir) - logger.log_epoch(epoch, model_save, inp=x, out=generated) - - From ca4f60a862e19a8e6ea6bfbce312d1c1840b0677 Mon Sep 17 00:00:00 2001 From: Philipp Haslbauer Date: Wed, 12 Jul 2023 18:58:48 +0200 Subject: [PATCH 04/30] fix scheduler resuming --- train.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/train.py b/train.py index dfa2ff7..09d37fc 100644 --- a/train.py +++ b/train.py @@ -66,14 +66,16 @@ def train(config, inpainting_network, kp_detector, bg_predictor, dense_motion_ne scheduler_optimizer = OneCycleLR(optimizer, max_lr=train_params['lr_generator'], total_steps=(len(dataset) // train_params['batch_size']) * train_params[ 'num_epochs'], - last_epoch=start_epoch - 1) + last_epoch=(len(dataset) // train_params['batch_size']) * (start_epoch - 1) + ) scheduler_bg_predictor = None if bg_predictor: scheduler_bg_predictor = OneCycleLR(optimizer_bg_predictor, max_lr=train_params['lr_generator'], total_steps=(len(dataset) // train_params['batch_size']) * train_params[ 'num_epochs'], - last_epoch=start_epoch - 1) + last_epoch=(len(dataset) // train_params['batch_size']) * (start_epoch - 1) + ) bg_predictor, optimizer_bg_predictor = accelerator.prepare(bg_predictor, optimizer_bg_predictor) generator_full = GeneratorFullModel(kp_detector, bg_predictor, dense_motion_network, inpainting_network, From 8e215fba8d9cd014f8a9a6a28bd03d1776d2ba84 Mon Sep 17 00:00:00 2001 From: TGG Date: Thu, 13 Jul 2023 01:07:27 +0200 Subject: [PATCH 05/30] add a few things --- modules/dense_motion.py | 2 ++ modules/inpainting_network.py | 15 +++++++++------ modules/util.py | 3 ++- train.py | 9 ++++++--- 4 files changed, 19 insertions(+), 10 deletions(-) diff --git a/modules/dense_motion.py b/modules/dense_motion.py index ced509a..d9f54e3 100644 --- a/modules/dense_motion.py +++ b/modules/dense_motion.py @@ -29,6 +29,7 @@ def __init__(self, block_expansion, num_blocks, max_features, num_tps, num_chann if multi_mask: up = [] self.up_nums = int(math.log(1/scale_factor, 2)) + #self.occlusion_num = 5 # usually 4, needs to be increased if layers in inpainting network are added self.occlusion_num = 4 channel = [hourglass_output_size[-1]//(2**i) for i in range(self.up_nums)] @@ -159,6 +160,7 @@ def forward(self, source_image, kp_driving, kp_source, bg_param = None, dropout_ occlusion_map.append(torch.sigmoid(self.occlusion[i+self.occlusion_num-self.up_nums](prediction))) else: occlusion_map.append(torch.sigmoid(self.occlusion[0](prediction[-1]))) + out_dict['occlusion_map'] = occlusion_map # Multi-resolution Occlusion Masks return out_dict diff --git a/modules/inpainting_network.py b/modules/inpainting_network.py index 6b873bd..c523091 100644 --- a/modules/inpainting_network.py +++ b/modules/inpainting_network.py @@ -49,6 +49,7 @@ def occlude_input(self, inp, occlusion_map): if not self.multi_mask: if inp.shape[2] != occlusion_map.shape[2] or inp.shape[3] != occlusion_map.shape[3]: occlusion_map = F.interpolate(occlusion_map, size=inp.shape[2:], mode='bilinear',align_corners=True) + out = inp * occlusion_map return out @@ -76,13 +77,14 @@ def forward(self, source_image, dense_motion): warped_encoder_maps = [] warped_encoder_maps.append(out_ij) + for i in range(self.num_down_blocks): + + out = self.resblock[2*i](out) # e.g. 
0, 2, 4, 6 + out = self.resblock[2*i+1](out) # e.g. 1, 3, 5, 7 + out = self.up_blocks[i](out) # e.g. 0, 1, 2, 3 - out = self.resblock[2*i](out) - out = self.resblock[2*i+1](out) - out = self.up_blocks[i](out) - - encode_i = encoder_map[-(i+2)] + encode_i = encoder_map[-(i+2)] # e.g. -2, -3, -4, -5 encode_ij = self.deform_input(encode_i.detach(), deformation) encode_i = self.deform_input(encode_i, deformation) @@ -120,7 +122,8 @@ def get_encode(self, driver_image, occlusion_map): encoder_map.append(self.occlude_input(out.detach(), occlusion_map[-1].detach())) for i in range(len(self.down_blocks)): out = self.down_blocks[i](out.detach()) - out_mask = self.occlude_input(out.detach(), occlusion_map[2-i].detach()) + #out_mask = self.occlude_input(out.detach(), occlusion_map[3-i].detach()) # is usually 2-i, must increase per block + out_mask = self.occlude_input(out.detach(), occlusion_map[-2-i].detach()) encoder_map.append(out_mask.detach()) return encoder_map diff --git a/modules/util.py b/modules/util.py index 2ccc367..cbdd1b1 100644 --- a/modules/util.py +++ b/modules/util.py @@ -169,6 +169,7 @@ def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1 self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, padding=padding, groups=groups) + self.norm = nn.InstanceNorm2d(out_features, affine=True) def forward(self, x): @@ -189,7 +190,7 @@ def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1 self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, padding=padding, groups=groups) self.norm = nn.InstanceNorm2d(out_features, affine=True) - self.pool = nn.AvgPool2d(kernel_size=(2, 2)) + self.pool = nn.MaxPool2d(kernel_size=(2, 2)) def forward(self, x): out = self.conv(x) diff --git a/train.py b/train.py index 09d37fc..b33b002 100644 --- a/train.py +++ b/train.py @@ -63,10 +63,13 @@ def train(config, inpainting_network, kp_detector, bg_predictor, dense_motion_ne dataloader = DataLoader(dataset, batch_size=train_params['batch_size'], shuffle=True, num_workers=train_params['dataloader_workers'], drop_last=True) + last_epoch = (len(dataset) // train_params['batch_size']) * (start_epoch - 1) + last_epoch = max(last_epoch, -1) + scheduler_optimizer = OneCycleLR(optimizer, max_lr=train_params['lr_generator'], total_steps=(len(dataset) // train_params['batch_size']) * train_params[ 'num_epochs'], - last_epoch=(len(dataset) // train_params['batch_size']) * (start_epoch - 1) + last_epoch=last_epoch ) scheduler_bg_predictor = None @@ -74,7 +77,7 @@ def train(config, inpainting_network, kp_detector, bg_predictor, dense_motion_ne scheduler_bg_predictor = OneCycleLR(optimizer_bg_predictor, max_lr=train_params['lr_generator'], total_steps=(len(dataset) // train_params['batch_size']) * train_params[ 'num_epochs'], - last_epoch=(len(dataset) // train_params['batch_size']) * (start_epoch - 1) + last_epoch=last_epoch ) bg_predictor, optimizer_bg_predictor = accelerator.prepare(bg_predictor, optimizer_bg_predictor) @@ -148,7 +151,7 @@ def train(config, inpainting_network, kp_detector, bg_predictor, dense_motion_ne 'kp_detector': accelerator.unwrap_model(kp_detector), 'optimizer': optimizer, 'bg_predictor': accelerator.unwrap_model(bg_predictor) if bg_predictor else None, - 'optimizer_bg_predictor': optimizer_bg_predictor + 'optimizer_bg_predictor': optimizer_bg_predictor, } logger.log_epoch(epoch, model_save, inp=x, out=generated) From 74c3f5da666eb5b0935c94cbc3016cc0cb6dbcce Mon Sep 
17 00:00:00 2001 From: TGG Date: Thu, 13 Jul 2023 11:44:32 +0200 Subject: [PATCH 06/30] fix adding additional layers --- modules/dense_motion.py | 5 +++-- modules/inpainting_network.py | 7 +++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/modules/dense_motion.py b/modules/dense_motion.py index d9f54e3..c274e21 100644 --- a/modules/dense_motion.py +++ b/modules/dense_motion.py @@ -12,7 +12,8 @@ class DenseMotionNetwork(nn.Module): """ def __init__(self, block_expansion, num_blocks, max_features, num_tps, num_channels, - scale_factor=0.25, bg = False, multi_mask = True, kp_variance=0.01): + scale_factor=0.25, bg = False, multi_mask = True, kp_variance=0.01, + occlusion_num = 4, **kwargs): super(DenseMotionNetwork, self).__init__() if scale_factor != 1: @@ -30,7 +31,7 @@ def __init__(self, block_expansion, num_blocks, max_features, num_tps, num_chann up = [] self.up_nums = int(math.log(1/scale_factor, 2)) #self.occlusion_num = 5 # usually 4, needs to be increased if layers in inpainting network are added - self.occlusion_num = 4 + self.occlusion_num = occlusion_num channel = [hourglass_output_size[-1]//(2**i) for i in range(self.up_nums)] for i in range(self.up_nums): diff --git a/modules/inpainting_network.py b/modules/inpainting_network.py index c523091..5c17482 100644 --- a/modules/inpainting_network.py +++ b/modules/inpainting_network.py @@ -120,10 +120,13 @@ def get_encode(self, driver_image, occlusion_map): out = self.first(driver_image) encoder_map = [] encoder_map.append(self.occlude_input(out.detach(), occlusion_map[-1].detach())) - for i in range(len(self.down_blocks)): + n_blocks = len(self.down_blocks) + for i in range(n_blocks): out = self.down_blocks[i](out.detach()) #out_mask = self.occlude_input(out.detach(), occlusion_map[3-i].detach()) # is usually 2-i, must increase per block - out_mask = self.occlude_input(out.detach(), occlusion_map[-2-i].detach()) + #out_mask = self.occlude_input(out.detach(), occlusion_map[2-i].detach()) + k = n_blocks - i - 1 + out_mask = self.occlude_input(out.detach(), occlusion_map[k].detach()) encoder_map.append(out_mask.detach()) return encoder_map From 84d5aeefb8ad39fb276a2a42dd3cda9925cc551a Mon Sep 17 00:00:00 2001 From: TGG Date: Thu, 13 Jul 2023 11:45:15 +0200 Subject: [PATCH 07/30] revert to avgpool --- modules/util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/util.py b/modules/util.py index cbdd1b1..da80783 100644 --- a/modules/util.py +++ b/modules/util.py @@ -190,7 +190,7 @@ def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1 self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, padding=padding, groups=groups) self.norm = nn.InstanceNorm2d(out_features, affine=True) - self.pool = nn.MaxPool2d(kernel_size=(2, 2)) + self.pool = nn.AvgPool2d(kernel_size=(2, 2)) def forward(self, x): out = self.conv(x) From 40bd97cdc3135d8c75dd17292a46344318d3e304 Mon Sep 17 00:00:00 2001 From: TGG Date: Thu, 13 Jul 2023 11:50:27 +0200 Subject: [PATCH 08/30] gan loss --- logger.py | 7 +++++- train.py | 75 +++++++++++++++++++++++++++++++++++++++++++++++++++---- 2 files changed, 76 insertions(+), 6 deletions(-) diff --git a/logger.py b/logger.py index b337975..ea625d6 100644 --- a/logger.py +++ b/logger.py @@ -62,7 +62,7 @@ def save_cpk(self, emergent=False): @staticmethod def load_cpk(checkpoint_path, inpainting_network=None, dense_motion_network=None, kp_detector=None, bg_predictor=None, avd_network=None, optimizer=None, 
optimizer_bg_predictor=None, - optimizer_avd=None): + optimizer_avd=None, discriminator=None, discriminator_optimizer=None): checkpoint = torch.load(checkpoint_path) if inpainting_network is not None: inpainting_network.load_state_dict(checkpoint['inpainting_network']) @@ -82,6 +82,11 @@ def load_cpk(checkpoint_path, inpainting_network=None, dense_motion_network=None if optimizer_avd is not None: if 'optimizer_avd' in checkpoint: optimizer_avd.load_state_dict(checkpoint['optimizer_avd']) + if discriminator is not None and 'discriminator' in checkpoint: + discriminator.load_state_dict(checkpoint['discriminator']) + if discriminator_optimizer is not None and 'optimizer_discriminator' in checkpoint: + discriminator_optimizer.load_state_dict(checkpoint['optimizer_discriminator']) + epoch = -1 if 'epoch' in checkpoint: epoch = checkpoint['epoch'] diff --git a/train.py b/train.py index b33b002..5d947dd 100644 --- a/train.py +++ b/train.py @@ -1,6 +1,9 @@ +import torchinfo from tqdm import trange import torch from torch.utils.data import DataLoader + +from gan import MultiScaleDiscriminator, discriminator_adversarial_loss, generator_adversarial_loss from logger import Logger from modules.model import GeneratorFullModel from torch.optim.lr_scheduler import OneCycleLR @@ -16,8 +19,11 @@ accelerator = Accelerator() -def train(config, inpainting_network, kp_detector, bg_predictor, dense_motion_network, checkpoint, log_dir, dataset, - optimizer_class=torch.optim.Adam +def train(config, inpainting_network, kp_detector, bg_predictor, dense_motion_network, checkpoint, + log_dir, dataset, + optimizer_class=torch.optim.Adam, + kp_detector_checkpoint=None, + bg_predictor_checkpoint=None, ): train_params = config['train_params'] optimizer_params = config['train_params'].get('optimizer_params', {}) @@ -29,6 +35,15 @@ def train(config, inpainting_network, kp_detector, bg_predictor, dense_motion_ne 'initial_lr': train_params['lr_generator']}], lr=train_params['lr_generator'], **optimizer_params) + discriminator = MultiScaleDiscriminator(scales=[1], d=64) + optimizer_discriminator = optimizer_class( + [{'params': list(discriminator.parameters()), 'initial_lr': train_params['lr_discriminator']}], + lr=train_params['lr_discriminator'], **optimizer_params) + + + + torchinfo.summary(discriminator, input_size=(1, 3, 256, 256)) + optimizer_bg_predictor = None if bg_predictor: optimizer_bg_predictor = optimizer_class( @@ -39,12 +54,22 @@ def train(config, inpainting_network, kp_detector, bg_predictor, dense_motion_ne start_epoch = Logger.load_cpk( checkpoint, inpainting_network=inpainting_network, dense_motion_network=dense_motion_network, kp_detector=kp_detector, bg_predictor=bg_predictor, - optimizer=optimizer, optimizer_bg_predictor=optimizer_bg_predictor) + optimizer=optimizer, optimizer_bg_predictor=optimizer_bg_predictor, + discriminator=discriminator, optimizer_discriminator=optimizer_discriminator) print('load success:', start_epoch) start_epoch += 1 else: start_epoch = 0 + if kp_detector_checkpoint is not None: + kp_params = torch.load(kp_detector_checkpoint) + kp_detector.load_state_dict(kp_params['kp_detector']) + print('load kp detector success') + if bg_predictor_checkpoint is not None: + bg_params = torch.load(bg_predictor_checkpoint) + bg_predictor.load_state_dict(bg_params['bg_predictor']) + print('load bg predictor success') + freeze_kp_detector = train_params.get('freeze_kp_detector', False) freeze_bg_predictor = train_params.get('freeze_bg_predictor', False) if freeze_kp_detector: @@ -71,6 +96,13 @@ 
def train(config, inpainting_network, kp_detector, bg_predictor, dense_motion_ne 'num_epochs'], last_epoch=last_epoch ) + discriminator_scheduler = OneCycleLR(optimizer_discriminator, max_lr=train_params['lr_discriminator'], + total_steps=(len(dataset) // train_params['batch_size']) * train_params[ + 'num_epochs'], + last_epoch=last_epoch + ) + + discriminator, optimizer_discriminator, discriminator_scheduler = accelerator.prepare(discriminator, optimizer_discriminator, discriminator_scheduler) scheduler_bg_predictor = None if bg_predictor: @@ -86,6 +118,7 @@ def train(config, inpainting_network, kp_detector, bg_predictor, dense_motion_ne bg_start = train_params['bg_start'] + inpainting_network, kp_detector, dense_motion_network, optimizer, scheduler_optimizer, dataloader, generator_full = accelerator.prepare( inpainting_network, kp_detector, dense_motion_network, optimizer, scheduler_optimizer, dataloader, generator_full) @@ -114,11 +147,37 @@ def train(config, inpainting_network, kp_detector, bg_predictor, dense_motion_ne models=[inpainting_network, dense_motion_network, kp_detector] ) as logger: for epoch in trange(start_epoch, train_params['num_epochs']): + i = 0 for x in tqdm(dataloader): losses_generator, generated = generator_full(x, epoch) + disc_loss = torch.zeros(1, device=x['driving'].device) + gen_loss = torch.zeros(1, device=x['driving'].device) + + if i % 2 == 0: + disc_pred_fake = discriminator(generated['prediction']) + disc_pred_real = discriminator(x['driving']) + for j in range(len(disc_pred_real)): # number of scales + disc_loss += discriminator_adversarial_loss(disc_pred_real[j], disc_pred_fake[j]) + else: + features_fake, fake_preds = discriminator.forward_with_features(generated['prediction']) + features_real, _ = discriminator.forward_with_features(x['driving']) + for k in range(len(fake_preds)): + gen_loss += generator_adversarial_loss(fake_preds[k]) + + losses_generator['gen'] = gen_loss + loss_values = [val.mean() for val in losses_generator.values()] loss = sum(loss_values) + + + if i % 2 == 0: + accelerator.backward(disc_loss, retain_graph=True) + + clip_grad_norm_(discriminator.parameters(), max_norm=10, norm_type=math.inf) + optimizer_discriminator.step() + optimizer_discriminator.zero_grad() + accelerator.backward(loss) clip_grad_norm_(kp_detector.parameters(), max_norm=10, norm_type=math.inf) @@ -128,6 +187,7 @@ def train(config, inpainting_network, kp_detector, bg_predictor, dense_motion_ne optimizer.step() + if bg_predictor and epoch >= bg_start and not freeze_bg_predictor: optimizer_bg_predictor.step() optimizer_bg_predictor.zero_grad() @@ -135,15 +195,18 @@ def train(config, inpainting_network, kp_detector, bg_predictor, dense_motion_ne optimizer.zero_grad() scheduler_optimizer.step() + discriminator_scheduler.step() losses = {key: value.mean().detach().data.cpu().numpy() for key, value in losses_generator.items()} lrs = { 'lr_generator': scheduler_optimizer.get_last_lr()[0], - 'lr_bg_predictor': scheduler_bg_predictor.get_last_lr()[0] if bg_predictor else 0 + 'lr_bg_predictor': scheduler_bg_predictor.get_last_lr()[0] if bg_predictor else 0, + 'lr_discriminator': discriminator_scheduler.get_last_lr()[0] } + losses['disc'] = disc_loss.mean().detach().data.cpu().numpy() logger.log_iter(losses=losses, others=lrs) - + i += 1 model_save = { 'inpainting_network': accelerator.unwrap_model(inpainting_network), @@ -152,6 +215,8 @@ def train(config, inpainting_network, kp_detector, bg_predictor, dense_motion_ne 'optimizer': optimizer, 'bg_predictor': 
accelerator.unwrap_model(bg_predictor) if bg_predictor else None, 'optimizer_bg_predictor': optimizer_bg_predictor, + 'discriminator': accelerator.unwrap_model(discriminator), + 'optimizer_discriminator': optimizer_discriminator, } logger.log_epoch(epoch, model_save, inp=x, out=generated) From 94031eee694c014f07d87ee245ab60ed30d53fd3 Mon Sep 17 00:00:00 2001 From: TGG Date: Fri, 14 Jul 2023 00:25:53 +0200 Subject: [PATCH 09/30] add some losses --- config/vox-256-finetune.yaml | 13 +++-- modules/inpainting_network.py | 7 +-- modules/model.py | 59 ++++++++++++++++++++- run.py | 9 +++- train.py | 98 +++++++++++++++++++---------------- 5 files changed, 130 insertions(+), 56 deletions(-) diff --git a/config/vox-256-finetune.yaml b/config/vox-256-finetune.yaml index 56833c8..55eb445 100644 --- a/config/vox-256-finetune.yaml +++ b/config/vox-256-finetune.yaml @@ -34,20 +34,21 @@ model_params: train_params: - num_epochs: 5 + num_epochs: 20 num_repeats: 10 - lr_generator: 2.0e-5 - batch_size: 16 + lr_generator: 2.0e-4 + lr_discriminator: 2.0e-3 + batch_size: 4 scales: [1, 0.5, 0.25, 0.125] dataloader_workers: 12 checkpoint_freq: 10 - dropout_epoch: 0 + dropout_epoch: 3 dropout_maxp: 0.3 dropout_startp: 0.1 dropout_inc_epoch: 10 bg_start: 6 freeze_kp_detector: False - freeze_bg_predictor: True + freeze_bg_predictor: False transform_params: sigma_affine: 0.05 sigma_tps: 0.005 @@ -57,6 +58,8 @@ train_params: equivariance_value: 10 warp_loss: 10 bg: 10 + l2: 0 + optimizer: 'adamw' optimizer_params: betas: [ 0.9, 0.999 ] diff --git a/modules/inpainting_network.py b/modules/inpainting_network.py index 5c17482..31e24a2 100644 --- a/modules/inpainting_network.py +++ b/modules/inpainting_network.py @@ -121,11 +121,12 @@ def get_encode(self, driver_image, occlusion_map): encoder_map = [] encoder_map.append(self.occlude_input(out.detach(), occlusion_map[-1].detach())) n_blocks = len(self.down_blocks) + + # len(occlusion_map) = n_blocks + 1, because of the original image size + for i in range(n_blocks): out = self.down_blocks[i](out.detach()) - #out_mask = self.occlude_input(out.detach(), occlusion_map[3-i].detach()) # is usually 2-i, must increase per block - #out_mask = self.occlude_input(out.detach(), occlusion_map[2-i].detach()) - k = n_blocks - i - 1 + k = -(i+2) # reverse index out_mask = self.occlude_input(out.detach(), occlusion_map[k].detach()) encoder_map.append(out_mask.detach()) diff --git a/modules/model.py b/modules/model.py index 335e348..d8a449a 100644 --- a/modules/model.py +++ b/modules/model.py @@ -1,10 +1,15 @@ +import math +import random + import torchvision +from facenet_pytorch.models.mtcnn import prewhiten from torch import nn import torch import torch.nn.functional as F from modules.util import AntiAliasInterpolation2d, TPS from torchvision import models import numpy as np +from facenet_pytorch import MTCNN, InceptionResnetV1 class Vgg19(torch.nn.Module): @@ -77,7 +82,8 @@ class GeneratorFullModel(torch.nn.Module): Merge all generator related updates into single model for better multi-gpu usage """ - def __init__(self, kp_extractor, bg_predictor, dense_motion_network, inpainting_network, train_params, *kwargs): + def __init__(self, kp_extractor, bg_predictor, dense_motion_network, inpainting_network, train_params, + *kwargs): super(GeneratorFullModel, self).__init__() self.kp_extractor = kp_extractor self.inpainting_network = inpainting_network @@ -106,6 +112,19 @@ def __init__(self, kp_extractor, bg_predictor, dense_motion_network, inpainting_ if torch.cuda.is_available(): self.vgg = 
self.vgg.cuda() + if self.loss_weights.get('id', 0) > 0: + self.id_recognition_model = InceptionResnetV1(pretrained='vggface2').eval() + self.id_recognition_model.requires_grad_(False) + self.mtcnn = MTCNN() + self.mtcnn.requires_grad_(False) + if torch.cuda.is_available(): + self.id_recognition_model = self.id_recognition_model.cuda() + self.mtcnn = self.mtcnn.cuda() + + else: + self.id_recognition_model = None + + def forward(self, x, epoch): kp_source = self.kp_extractor(x['source']) @@ -126,9 +145,11 @@ def forward(self, x, epoch): dense_motion = self.dense_motion_network(source_image=x['source'], kp_driving=kp_driving, kp_source=kp_source, bg_param = bg_param, dropout_flag = dropout_flag, dropout_p = dropout_p) + generated = self.inpainting_network(x['source'], dense_motion) generated.update({'kp_source': kp_source, 'kp_driving': kp_driving}) + loss_values = {} pyramide_real = self.pyramid(x['driving']) @@ -180,4 +201,40 @@ def forward(self, x, epoch): value = torch.abs(eye - value).mean() loss_values['bg'] = self.loss_weights['bg'] * value + # l2 loss + if self.loss_weights['l2'] != 0: + loss_values['l2'] = torch.abs(generated['prediction'] - x['driving']).mean() + loss_values['l2'] = self.loss_weights['l2'] * loss_values['l2'] + + # huber loss + if self.loss_weights['huber'] != 0: + loss_values['huber'] = F.smooth_l1_loss(generated['prediction'], x['driving']) + loss_values['huber'] = self.loss_weights['huber'] * loss_values['huber'] + + # id loss + if self.id_recognition_model and self.loss_weights['id'] != 0: + try: + driving_preprocessed = x['driving'] * 255 + driving_preprocessed = driving_preprocessed.permute(0, 2, 3, 1) + driving_preprocessed = driving_preprocessed.to(torch.float16) + driving_preprocessed = self.mtcnn(driving_preprocessed) + generated_preprocessed = generated['prediction'] * 255 + generated_preprocessed = generated_preprocessed.permute(0, 2, 3, 1) + generated_preprocessed = generated_preprocessed.to(torch.float16) + + generated_preprocessed = self.mtcnn(generated_preprocessed) + except Exception as e: + print('MTCNN failed, using bilinear interpolation') + print(e) + driving_preprocessed = prewhiten(x['driving']) + driving_preprocessed = torch.nn.functional.interpolate(driving_preprocessed, size=(160, 160), mode='bilinear', align_corners=True) + generated_preprocessed = prewhiten(generated['prediction']) + generated_preprocessed = torch.nn.functional.interpolate(generated_preprocessed, size=(160, 160), mode='bilinear', align_corners=True) + id_real = self.id_recognition_model(driving_preprocessed) + id_generated = self.id_recognition_model(generated_preprocessed) + #cosine + value = 1 - torch.nn.functional.cosine_similarity(id_real, id_generated, dim=1) + loss_values['id'] = self.loss_weights['id'] * value.mean() + + return loss_values, generated diff --git a/run.py b/run.py index f182542..ddb5c17 100644 --- a/run.py +++ b/run.py @@ -38,6 +38,8 @@ parser.add_argument("--mode", default="train", choices=["train", "reconstruction", "train_avd"]) parser.add_argument("--log_dir", default='log', help="path to log into") parser.add_argument("--checkpoint", default=None, help="path to checkpoint to restore") + parser.add_argument("--kp_detector", default=None, help="path to kp_detector checkpoint to restore") + parser.add_argument("--bg_predictor", default=None, help="path to bg_predictor checkpoint to restore") parser.add_argument("--detect_anomaly", action="store_true", help="detect anomaly in autograd") @@ -94,8 +96,11 @@ if opt.mode == 'train': 
print("Training...") - train(config, inpainting, kp_detector, bg_predictor, dense_motion_network, opt.checkpoint, log_dir, dataset, - optimizer_class=optimizer_class) + train(config, inpainting, kp_detector, bg_predictor, dense_motion_network, opt.checkpoint, + log_dir, dataset, + optimizer_class=optimizer_class, + kp_detector_checkpoint=opt.kp_detector, + bg_predictor_checkpoint=opt.bg_predictor) elif opt.mode == 'train_avd': print("Training Animation via Disentaglement...") train_avd(config, inpainting, kp_detector, bg_predictor, dense_motion_network, avd_network, opt.checkpoint, diff --git a/train.py b/train.py index 5d947dd..d0d2aa6 100644 --- a/train.py +++ b/train.py @@ -141,72 +141,80 @@ def train(config, inpainting_network, kp_detector, bg_predictor, dense_motion_ne graph_name='dense_motion_network') draw_graph(inpainting_network, input_data=[sample['source'], dense_motion], save_graph=True, directory=log_dir, graph_name='inpainting_network') - + model_list = [inpainting_network, dense_motion_network, discriminator] + if bg_predictor: + model_list.append(bg_predictor) + if not freeze_kp_detector: + model_list.append(kp_detector) with Logger(log_dir=log_dir, visualizer_params=config['visualizer_params'], checkpoint_freq=train_params['checkpoint_freq'], - models=[inpainting_network, dense_motion_network, kp_detector] + models=model_list, ) as logger: for epoch in trange(start_epoch, train_params['num_epochs']): i = 0 for x in tqdm(dataloader): - losses_generator, generated = generator_full(x, epoch) - disc_loss = torch.zeros(1, device=x['driving'].device) - gen_loss = torch.zeros(1, device=x['driving'].device) + with (accelerator.accumulate(generator_full), accelerator.accumulate(discriminator), + accelerator.accumulate(inpainting_network), accelerator.accumulate(dense_motion_network), + accelerator.accumulate(kp_detector), accelerator.accumulate(bg_predictor)): + losses_generator, generated = generator_full(x, epoch) + disc_loss = torch.zeros(1, device=x['driving'].device) + gen_loss = torch.zeros(1, device=x['driving'].device) + + if i % 2 == 0: + disc_pred_fake = discriminator(generated['prediction']) + disc_pred_real = discriminator(x['driving']) + for j in range(len(disc_pred_real)): # number of scales + disc_loss += discriminator_adversarial_loss(disc_pred_real[j], disc_pred_fake[j]) + else: + features_fake, fake_preds = discriminator.forward_with_features(generated['prediction']) + features_real, _ = discriminator.forward_with_features(x['driving']) + for k in range(len(fake_preds)): + gen_loss += generator_adversarial_loss(fake_preds[k]) - if i % 2 == 0: - disc_pred_fake = discriminator(generated['prediction']) - disc_pred_real = discriminator(x['driving']) - for j in range(len(disc_pred_real)): # number of scales - disc_loss += discriminator_adversarial_loss(disc_pred_real[j], disc_pred_fake[j]) - else: - features_fake, fake_preds = discriminator.forward_with_features(generated['prediction']) - features_real, _ = discriminator.forward_with_features(x['driving']) - for k in range(len(fake_preds)): - gen_loss += generator_adversarial_loss(fake_preds[k]) + losses_generator['gen'] = gen_loss - losses_generator['gen'] = gen_loss + loss_values = [val.mean() for val in losses_generator.values()] + loss = sum(loss_values) - loss_values = [val.mean() for val in losses_generator.values()] - loss = sum(loss_values) + if i % 2 == 0: + accelerator.backward(disc_loss, retain_graph=True) - if i % 2 == 0: - accelerator.backward(disc_loss, retain_graph=True) + 
clip_grad_norm_(discriminator.parameters(), max_norm=10, norm_type=math.inf) + optimizer_discriminator.step() + optimizer_discriminator.zero_grad() - clip_grad_norm_(discriminator.parameters(), max_norm=10, norm_type=math.inf) - optimizer_discriminator.step() - optimizer_discriminator.zero_grad() + accelerator.backward(loss) - accelerator.backward(loss) + clip_grad_norm_(kp_detector.parameters(), max_norm=10, norm_type=math.inf) + clip_grad_norm_(dense_motion_network.parameters(), max_norm=10, norm_type=math.inf) + if bg_predictor and epoch >= bg_start and not freeze_bg_predictor: + clip_grad_norm_(bg_predictor.parameters(), max_norm=10, norm_type=math.inf) - clip_grad_norm_(kp_detector.parameters(), max_norm=10, norm_type=math.inf) - clip_grad_norm_(dense_motion_network.parameters(), max_norm=10, norm_type=math.inf) - if bg_predictor and epoch >= bg_start and not freeze_bg_predictor: - clip_grad_norm_(bg_predictor.parameters(), max_norm=10, norm_type=math.inf) + optimizer.step() - optimizer.step() + if bg_predictor and epoch >= bg_start and not freeze_bg_predictor: + optimizer_bg_predictor.step() + optimizer_bg_predictor.zero_grad() + scheduler_bg_predictor.step() - if bg_predictor and epoch >= bg_start and not freeze_bg_predictor: - optimizer_bg_predictor.step() - optimizer_bg_predictor.zero_grad() - scheduler_bg_predictor.step() + scheduler_optimizer.step() + optimizer.zero_grad() - optimizer.zero_grad() - scheduler_optimizer.step() - discriminator_scheduler.step() + discriminator_scheduler.step() - losses = {key: value.mean().detach().data.cpu().numpy() for key, value in losses_generator.items()} - lrs = { - 'lr_generator': scheduler_optimizer.get_last_lr()[0], - 'lr_bg_predictor': scheduler_bg_predictor.get_last_lr()[0] if bg_predictor else 0, - 'lr_discriminator': discriminator_scheduler.get_last_lr()[0] - } - losses['disc'] = disc_loss.mean().detach().data.cpu().numpy() - logger.log_iter(losses=losses, others=lrs) + losses = {key: value.mean().detach().data.cpu().numpy() for key, value in losses_generator.items()} + lrs = { + 'lr_generator': scheduler_optimizer.get_last_lr()[0], + 'lr_bg_predictor': scheduler_bg_predictor.get_last_lr()[0] if bg_predictor else 0, + 'lr_discriminator': discriminator_scheduler.get_last_lr()[0] + } + losses['disc'] = disc_loss.mean().detach().data.cpu().numpy() + logger.log_iter(losses=losses, others=lrs) - i += 1 + i += 1 model_save = { 'inpainting_network': accelerator.unwrap_model(inpainting_network), From 6dad8b4de8a93b587611062e487eedc46e569b54 Mon Sep 17 00:00:00 2001 From: TGG Date: Wed, 19 Jul 2023 16:44:00 +0200 Subject: [PATCH 10/30] add lots of stuff --- config/vox-1024-deeper.yaml | 92 ++++++++++++++++++ config/vox-256-deeper.yaml | 86 +++++++++++++++++ config/vox-512-deeper.yaml | 94 ++++++++++++++++++ config/vox-512-finetune.yaml | 4 +- config/vox-768-deeper.yaml | 94 ++++++++++++++++++ demo.py | 2 +- frames_dataset.py | 9 +- gan.py | 176 ++++++++++++++++++++++++++++++++++ modules/inpainting_network.py | 38 +++++--- modules/model.py | 68 ++++++++----- run.py | 8 +- save_model_only.py | 30 ++++-- train.py | 109 ++++++++++++++------- utils.py | 51 ++++++++++ 14 files changed, 774 insertions(+), 87 deletions(-) create mode 100644 config/vox-1024-deeper.yaml create mode 100644 config/vox-256-deeper.yaml create mode 100644 config/vox-512-deeper.yaml create mode 100644 config/vox-768-deeper.yaml create mode 100644 gan.py diff --git a/config/vox-1024-deeper.yaml b/config/vox-1024-deeper.yaml new file mode 100644 index 0000000..3bd3e5d --- 
/dev/null +++ b/config/vox-1024-deeper.yaml @@ -0,0 +1,92 @@ +dataset_params: + root_dir: ./vox512_filtered_webp + frame_shape: 1024,1024,3 + id_sampling: True + augmentation_params: + flip_param: + horizontal_flip: True + time_flip: True + jitter_param: + brightness: 0.1 + contrast: 0.1 + saturation: 0.1 + hue: 0.1 + + +model_params: + common_params: + num_tps: 10 + num_channels: 3 + bg: True + multi_mask: True + generator_params: + block_expansion: 64 + max_features: 512 + num_down_blocks: 4 + dense_motion_params: + block_expansion: 64 + max_features: 1024 + num_blocks: 5 + scale_factor: 0.25 # might make sense to set to 0.5 because of the additional occlusion (4=>5) + occlusion_num: 5 + avd_network_params: + id_bottle_size: 128 + pose_bottle_size: 128 + + + +train_params: + num_epochs: 80 + num_repeats: 2 + lr_generator: 2.0e-5 + lr_discriminator: 2.0e-5 + batch_size: 1 + scales: [1, 0.5, 0.25, 0.125, 0.0625, 0.03125] + dataloader_workers: 8 + checkpoint_freq: 5 + dropout_epoch: 0 + dropout_maxp: 0.3 + dropout_startp: 0.1 + dropout_inc_epoch: 10 + bg_start: 101 + freeze_kp_detector: True + freeze_bg_predictor: True + freeze_dense_motion: False + transform_params: + sigma_affine: 0.05 + sigma_tps: 0.005 + points_tps: 5 + loss_weights: + perceptual: [5, 5, 5, 5, 5] + equivariance_value: 10 + warp_loss: 10 + bg: 0 + l2: 0 + id: 0.1 + huber: 0 + generator_gan: 10 + generator_feat_match: 100 + discriminator_gan: 10 + optimizer: 'adamw' + optimizer_params: + betas: [ 0.9, 0.999 ] + weight_decay: 1.0e-3 + scheduler: 'onecycle' + scheduler_params: + pct_start: 0.01 + +train_avd_params: + num_epochs: 100 + num_repeats: 1 + batch_size: 8 + dataloader_workers: 6 + checkpoint_freq: 1 + epoch_milestones: [10, 20] + lr: 1.0e-3 + lambda_shift: 1 + random_scale: 0.25 + +visualizer_params: + kp_size: 5 + draw_border: True + colormap: 'gist_rainbow' \ No newline at end of file diff --git a/config/vox-256-deeper.yaml b/config/vox-256-deeper.yaml new file mode 100644 index 0000000..047a65c --- /dev/null +++ b/config/vox-256-deeper.yaml @@ -0,0 +1,86 @@ +dataset_params: + root_dir: ./vox512_webp + frame_shape: 256,256,3 + id_sampling: True + augmentation_params: + flip_param: + horizontal_flip: True + time_flip: True + jitter_param: + brightness: 0.1 + contrast: 0.1 + saturation: 0.1 + hue: 0.1 + + +model_params: + common_params: + num_tps: 10 + num_channels: 3 + bg: True + multi_mask: True + generator_params: + block_expansion: 64 + max_features: 512 + num_down_blocks: 4 + dense_motion_params: + block_expansion: 64 + max_features: 512 + num_blocks: 5 + scale_factor: 0.25 # might make sense to set to 0.5 because of the additional occlusion (4=>5) + occlusion_num: 5 + avd_network_params: + id_bottle_size: 128 + pose_bottle_size: 128 + + +train_params: + num_epochs: 10 + num_repeats: 3 + lr_generator: 2.0e-4 + lr_discriminator: 2.0e-4 + batch_size: 4 + scales: [1, 0.5, 0.25, 0.125] + dataloader_workers: 8 + checkpoint_freq: 10 + dropout_epoch: 35 + dropout_maxp: 0.3 + dropout_startp: 0.1 + dropout_inc_epoch: 10 + bg_start: 101 + freeze_kp_detector: True + freeze_bg_predictor: True + freeze_dense_motion: False + transform_params: + sigma_affine: 0.05 + sigma_tps: 0.005 + points_tps: 5 + loss_weights: + perceptual: [10, 10, 10, 10, 10] + equivariance_value: 10 + warp_loss: 10 + bg: 0 + l2: 0 + id: 0 + huber: 1 + optimizer: 'adamw' + optimizer_params: + betas: [ 0.9, 0.999 ] + weight_decay: 1.0e-3 + + +train_avd_params: + num_epochs: 100 + num_repeats: 1 + batch_size: 8 + dataloader_workers: 6 + checkpoint_freq: 
1 + epoch_milestones: [10, 20] + lr: 1.0e-3 + lambda_shift: 1 + random_scale: 0.25 + +visualizer_params: + kp_size: 5 + draw_border: True + colormap: 'gist_rainbow' \ No newline at end of file diff --git a/config/vox-512-deeper.yaml b/config/vox-512-deeper.yaml new file mode 100644 index 0000000..aa791a8 --- /dev/null +++ b/config/vox-512-deeper.yaml @@ -0,0 +1,94 @@ +dataset_params: + root_dir: ./vox512_filtered_webp + frame_shape: 512,512,3 + id_sampling: True + augmentation_params: + flip_param: + horizontal_flip: True + time_flip: True + jitter_param: + brightness: 0.1 + contrast: 0.1 + saturation: 0.1 + hue: 0.1 + + +model_params: + common_params: + num_tps: 10 + num_channels: 3 + bg: True + multi_mask: True + generator_params: + block_expansion: 64 + max_features: 512 + num_down_blocks: 4 + concat_encode: False + use_skip_blocks: True + dense_motion_params: + block_expansion: 64 + max_features: 1024 + num_blocks: 5 + scale_factor: 0.25 # might make sense to set to 0.5 because of the additional occlusion (4=>5) + occlusion_num: 5 + avd_network_params: + id_bottle_size: 128 + pose_bottle_size: 128 + + +train_params: + num_epochs: 80 + num_repeats: 2 + lr_generator: 2.0e-4 + lr_discriminator: 2.0e-4 + batch_size: 2 + scales: [1, 0.5, 0.25, 0.125] + dataloader_workers: 8 + checkpoint_freq: 5 + dropout_epoch: 0 + dropout_maxp: 0.3 + dropout_startp: 0.1 + dropout_inc_epoch: 10 + bg_start: 101 + freeze_kp_detector: True + freeze_bg_predictor: True + freeze_dense_motion: False + transform_params: + sigma_affine: 0.05 + sigma_tps: 0.005 + points_tps: 5 + loss_weights: + perceptual: [5, 5, 5, 5, 5] + equivariance_value: 10 + warp_loss: 10 + bg: 0 + l2: 0 + id: 1 + huber: 0 + generator_gan: 10 + generator_feat_match: 100 + discriminator_gan: 10 + + optimizer: 'adamw' + optimizer_params: + betas: [ 0.9, 0.999 ] + weight_decay: 1.0e-3 + scheduler: 'onecycle' + scheduler_params: + pct_start: 0.1 + +train_avd_params: + num_epochs: 100 + num_repeats: 1 + batch_size: 8 + dataloader_workers: 6 + checkpoint_freq: 1 + epoch_milestones: [10, 20] + lr: 1.0e-3 + lambda_shift: 1 + random_scale: 0.25 + +visualizer_params: + kp_size: 5 + draw_border: True + colormap: 'gist_rainbow' \ No newline at end of file diff --git a/config/vox-512-finetune.yaml b/config/vox-512-finetune.yaml index a84e6a9..4574848 100644 --- a/config/vox-512-finetune.yaml +++ b/config/vox-512-finetune.yaml @@ -39,7 +39,8 @@ train_params: num_repeats: 4 # Higher LR seems to bring problems when finetuning lr_generator: 2.0e-4 - batch_size: 4 + lr_discriminator: 2.0e-3 + batch_size: 2 scales: [1, 0.5, 0.25, 0.125, 0.0625] dataloader_workers: 6 checkpoint_freq: 5 @@ -59,6 +60,7 @@ train_params: equivariance_value: 10 warp_loss: 10 bg: 10 + l2: 0 optimizer: 'adamw' optimizer_params: betas: [0.9, 0.999] diff --git a/config/vox-768-deeper.yaml b/config/vox-768-deeper.yaml new file mode 100644 index 0000000..9f19e8b --- /dev/null +++ b/config/vox-768-deeper.yaml @@ -0,0 +1,94 @@ +name: vox-768-deeper +dataset_params: + root_dir: ./video-preprocessing/vox2-768 + frame_shape: 768,768,3 + id_sampling: True + augmentation_params: + flip_param: + horizontal_flip: True + time_flip: True + jitter_param: + brightness: 0.1 + contrast: 0.1 + saturation: 0.1 + hue: 0.1 + + +model_params: + common_params: + num_tps: 10 + num_channels: 3 + bg: True + multi_mask: True + generator_params: + block_expansion: 64 + max_features: 512 + num_down_blocks: 4 + concat_encode: False + dense_motion_params: + block_expansion: 64 + max_features: 1024 + num_blocks: 5 + 
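# Note (assumption carried over from the original TPSMM dense-motion design): scale_factor
# downsamples the source before flow/occlusion estimation, so 0.25 on 512x512 inputs means
# motion and occlusion maps are computed at roughly 128x128.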
scale_factor: 0.25 # might make sense to set to 0.5 because of the additional occlusion (4=>5) + occlusion_num: 5 + avd_network_params: + id_bottle_size: 128 + pose_bottle_size: 128 + + + +train_params: + num_epochs: 20 + num_repeats: 2 + lr_generator: 2.0e-4 + lr_discriminator: 2.0e-4 + batch_size: 1 + scales: [1, 0.5, 0.25, 0.125, 0.0625] + dataloader_workers: 8 + checkpoint_freq: 5 + dropout_epoch: 0 + dropout_maxp: 0.3 + dropout_startp: 0.1 + dropout_inc_epoch: 10 + bg_start: 101 + freeze_kp_detector: True + freeze_bg_predictor: True + freeze_dense_motion: False + transform_params: + sigma_affine: 0.05 + sigma_tps: 0.005 + points_tps: 5 + loss_weights: + perceptual: [5, 5, 5, 5, 5] + equivariance_value: 10 + warp_loss: 10 + bg: 0 + l2: 0 + id: 0.1 + huber: 0 + generator_gan: 10 + generator_feat_match: 100 + discriminator_gan: 10 + optimizer: 'adamw' + optimizer_params: + betas: [ 0.9, 0.999 ] + weight_decay: 1.0e-3 + scheduler: 'onecycle' + scheduler_params: + pct_start: 0.01 + +train_avd_params: + num_epochs: 100 + num_repeats: 1 + batch_size: 8 + dataloader_workers: 6 + checkpoint_freq: 1 + epoch_milestones: [10, 20] + lr: 1.0e-3 + lambda_shift: 1 + random_scale: 0.25 + +visualizer_params: + kp_size: 5 + draw_border: True + colormap: 'gist_rainbow' \ No newline at end of file diff --git a/demo.py b/demo.py index 3042b75..21f7a63 100644 --- a/demo.py +++ b/demo.py @@ -185,7 +185,7 @@ def read_and_resize_frames_backward(video_path, img_shape, end_frame): parser.add_argument("--mode", default='relative', choices=['standard', 'relative', 'avd'], help="Animate mode: ['standard', 'relative', 'avd'], when use the relative mode to animate a face, use '--find_best_frame' can get better quality result") - parser.add_argument("--find_best_frame", dest="find_best_frame", action="store_true", + parser.add_argument("-fb", "--find_best_frame", dest="find_best_frame", action="store_true", help="Generate from the frame that is the most alligned with source. 
(Only for faces, requires face_aligment lib)") parser.add_argument("--cpu", dest="cpu", action="store_true", help="cpu mode.") diff --git a/frames_dataset.py b/frames_dataset.py index afc5a9c..edeaf71 100644 --- a/frames_dataset.py +++ b/frames_dataset.py @@ -1,4 +1,6 @@ import os + +from albumentations import AdvancedBlur from skimage import io, img_as_float32 from skimage.color import gray2rgb from sklearn.model_selection import train_test_split @@ -99,6 +101,7 @@ def __init__(self, root_dir, frame_shape=(256, 256, 3), id_sampling=False, is_tr else: self.transform = None + def __len__(self): return len(self.videos) @@ -118,7 +121,7 @@ def __getitem__(self, idx): frames = os.listdir(path) num_frames = len(frames) # use more frames that are different from each other to speed up training - min_frames_apart = num_frames // 4 + min_frames_apart = num_frames // 3 first_frame_idx = np.random.choice(num_frames - min_frames_apart) second_frame_idx = np.random.choice(range(first_frame_idx + min_frames_apart, num_frames)) frame_idx = np.array([first_frame_idx, second_frame_idx]) @@ -150,7 +153,9 @@ def __getitem__(self, idx): out = {} if self.is_train: - source = np.array(video_array[0], dtype='float32') + source = video_array[0] + source = np.array(source, dtype='float32') + driving = np.array(video_array[1], dtype='float32') out['driving'] = driving.transpose((2, 0, 1)) diff --git a/gan.py b/gan.py new file mode 100644 index 0000000..a6e2e49 --- /dev/null +++ b/gan.py @@ -0,0 +1,176 @@ + +import torch +import torch.nn as nn + +def generator_adversarial_loss(fake_preds): + #clipped_fake_preds = torch.clamp(fake_preds, -1.0, 1.0) + # todo do I have to apply the gen loss also across all layers? + return -torch.mean(fake_preds) + +def discriminator_adversarial_loss(real_preds, fake_preds, label_smoothing_stddev=0.1): + smoothed_real_label = torch.normal(torch.tensor(1.0), torch.tensor(label_smoothing_stddev), size=real_preds.shape).to(real_preds.device) + real_loss = torch.mean(nn.ReLU()(smoothed_real_label - real_preds)) + fake_loss = torch.mean(nn.ReLU()(smoothed_real_label + fake_preds)) + return real_loss + fake_loss + +class ConvBlock(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, activation1=nn.Mish(), activation2=nn.Mish()): + super(ConvBlock, self).__init__() + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, 1, padding="same") + + self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, 1, padding="same") + self.conv2 = nn.utils.spectral_norm(self.conv2) + self.activation1 = activation1 + self.activation2 = activation2 + + self.conv_skip = nn.Conv2d(in_channels, out_channels, 1, 1) + + + def forward(self, x): + input_ = self.conv_skip(x) + x = self.conv(x) + x = self.activation1(x) + x = self.conv2(x) + if self.activation2 is not None: + x = self.activation2(x) + return x + input_ + +def normal_init(m, mean, std): + if isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Conv2d): + m.weight.data.normal_(mean, std) + m.bias.data.zero_() + +class Discriminator(torch.nn.Module): + # initializers + def __init__(self, d=32, scale=1): + super(Discriminator, self).__init__() + self.conv1 = torch.nn.Sequential( + ConvBlock(3, d, 3), + nn.GroupNorm(1, d), + #nn.InstanceNorm2d(d), + nn.AvgPool2d(3, stride=2, padding=1, count_include_pad=False), + ) + + self.conv2 = torch.nn.Sequential( + ConvBlock(d, d * 2, 3), + nn.GroupNorm(1, d * 2), + #nn.InstanceNorm2d(d * 2), + nn.AvgPool2d(3, stride=2, padding=1, count_include_pad=False), + ) + + self.conv3 = 
torch.nn.Sequential( + ConvBlock(d * 2, d * 4, 3), + nn.GroupNorm(1, d * 4), + #nn.InstanceNorm2d(d * 4), + nn.AvgPool2d(3, stride=2, padding=1, count_include_pad=False), + ) + + self.conv4 = torch.nn.Sequential( + ConvBlock(d * 4, d * 8, 3), + nn.GroupNorm(1, d * 8), + #nn.InstanceNorm2d(d * 8), + nn.AvgPool2d(3, stride=2, padding=1, count_include_pad=False), + ) + + self.conv5 = torch.nn.Sequential( + ConvBlock(d * 8, d * 16, 3), + ) + + assert scale in [1, 2, 4, 8, 16], "Scale should be 1, 2, 4, 8 or 16" + + self.scale = scale + + self.scaler = torch.nn.functional.interpolate + + self.weight_init(mean=0.0, std=0.02) + + # weight_init + def weight_init(self, mean, std): + for m in self._modules: + normal_init(self._modules[m], mean, std) + + # forward method + def forward(self, x): + x = self.scaler(x, scale_factor=self.scale, mode='bilinear', align_corners=False) + x = self.conv1(x) + x = self.conv2(x) + x = self.conv3(x) + x = self.conv4(x) + x = self.conv5(x) + + return x + +class DiscriminatorWithFeatures(Discriminator): + def forward_with_features(self, x): + x = self.scaler(x, scale_factor=self.scale, mode='bilinear', align_corners=False) + x1 = self.conv1(x) + x2 = self.conv2(x1) + x3 = self.conv3(x2) + x4 = self.conv4(x3) + x5 = self.conv5(x4) + + assert not torch.isnan(x1).any(), "x1 contains NaN values" + assert not torch.isnan(x2).any(), "x2 contains NaN values" + assert not torch.isnan(x3).any(), "x3 contains NaN values" + assert not torch.isnan(x4).any(), "x4 contains NaN values" + assert not torch.isnan(x5).any(), "x5 contains NaN values" + + return [x1, x2, x3,x4], x5 + +class MultiScaleDiscriminator(torch.nn.Module): + def __init__(self, d=32, scales=(1, 2)): + super(MultiScaleDiscriminator, self).__init__() + self.discriminators = nn.ModuleList([ + DiscriminatorWithFeatures(d, scale) for scale in scales + ]) + self.scales = scales + + def forward(self, x): + results = [] + for i, _ in enumerate(self.scales): + results.append(self.discriminators[i].forward(x)) + return results + + def forward_with_features(self, x): + x1x2x3x4 = [] + x5 = [] + for i, _ in enumerate(self.scales): + x1x2x3x4_, x5_ = self.discriminators[i].forward_with_features(x) + + x1x2x3x4.append(x1x2x3x4_) + x5.append(x5_) + + return x1x2x3x4, x5 + +def compute_gradient_penalty(discriminator, real_samples, fake_samples, device='cuda'): + batch_size = real_samples.size(0) + alpha = torch.rand(batch_size, 1, 1, 1).to(device) + interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True) + + d_interpolates = discriminator(interpolates) + fake = torch.ones_like(d_interpolates).to(device) + + gradients = torch.autograd.grad( + outputs=d_interpolates, + inputs=interpolates, + grad_outputs=fake, + create_graph=True, + retain_graph=True, + only_inputs=True + )[0] + + gradients = gradients.view(batch_size, -1) + gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() + + return gradient_penalty +def weak_feature_matching_loss(predicted_features, target_features, start_layer=2): + num_layers = len(predicted_features) + loss = 0 + for i in range(start_layer, num_layers): + + num_elements = torch.prod(torch.tensor(predicted_features[i].shape[1:])) + layer_loss = nn.L1Loss()(predicted_features[i], target_features[i]) / num_elements + assert not torch.isnan(layer_loss).any(), f"layer_loss at layer {i} contains NaN values" + assert torch.isfinite(layer_loss).all(), f"layer_loss at layer {i} contains Inf values" + loss += layer_loss + return loss diff --git 
a/modules/inpainting_network.py b/modules/inpainting_network.py index 31e24a2..73b101f 100644 --- a/modules/inpainting_network.py +++ b/modules/inpainting_network.py @@ -9,21 +9,31 @@ class InpaintingNetwork(nn.Module): """ Inpaint the missing regions and reconstruct the Driving image. """ - def __init__(self, num_channels, block_expansion, max_features, num_down_blocks, multi_mask = True, **kwargs): + def __init__(self, num_channels, block_expansion, max_features, num_down_blocks, multi_mask = True, + concat_encode=True, use_skip_blocks=False, + **kwargs): super(InpaintingNetwork, self).__init__() self.num_down_blocks = num_down_blocks self.multi_mask = multi_mask self.first = SameBlock2d(num_channels, block_expansion, kernel_size=(7, 7), padding=(3, 3)) + self.concat_encode = concat_encode down_blocks = [] up_blocks = [] - resblock = [] + resblock = []# + skip_blocks = [] for i in range(num_down_blocks): in_features = min(max_features, block_expansion * (2 ** i)) out_features = min(max_features, block_expansion * (2 ** (i + 1))) down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1))) - decoder_in_feature = out_features * 2 + if use_skip_blocks: + skip_blocks.append(nn.Conv2d(in_features, out_features, kernel_size=(1, 1))) + if concat_encode: + decoder_in_feature = out_features * 2 + else: + decoder_in_feature = out_features + if i==num_down_blocks-1: decoder_in_feature = out_features up_blocks.append(UpBlock2d(decoder_in_feature, in_features, kernel_size=(3, 3), padding=(1, 1))) @@ -32,6 +42,10 @@ def __init__(self, num_channels, block_expansion, max_features, num_down_blocks, self.down_blocks = nn.ModuleList(down_blocks) self.up_blocks = nn.ModuleList(up_blocks[::-1]) self.resblock = nn.ModuleList(resblock[::-1]) + if skip_blocks: + self.skip_blocks = nn.ModuleList(skip_blocks[::-1]) + else: + self.skip_blocks = None self.final = nn.Conv2d(block_expansion, num_channels, kernel_size=(7, 7), padding=(3, 3)) self.num_channels = num_channels @@ -58,6 +72,8 @@ def forward(self, source_image, dense_motion): encoder_map = [out] for i in range(len(self.down_blocks)): out = self.down_blocks[i](out) + if self.skip_blocks: + out = self.skip_blocks[i](out) encoder_map.append(out) output_dict = {} @@ -68,15 +84,10 @@ def forward(self, source_image, dense_motion): output_dict['occlusion_map'] = occlusion_map deformation = dense_motion['deformation'] - out_ij = self.deform_input(out.detach(), deformation) out = self.deform_input(out, deformation) - - out_ij = self.occlude_input(out_ij, occlusion_map[0].detach()) out = self.occlude_input(out, occlusion_map[0]) - warped_encoder_maps = [] - warped_encoder_maps.append(out_ij) - + warped_encoder_maps = [out.detach()] for i in range(self.num_down_blocks): @@ -85,20 +96,21 @@ def forward(self, source_image, dense_motion): out = self.up_blocks[i](out) # e.g. 0, 1, 2, 3 encode_i = encoder_map[-(i+2)] # e.g. 
-2, -3, -4, -5 - encode_ij = self.deform_input(encode_i.detach(), deformation) encode_i = self.deform_input(encode_i, deformation) occlusion_ind = 0 if self.multi_mask: occlusion_ind = i+1 - encode_ij = self.occlude_input(encode_ij, occlusion_map[occlusion_ind].detach()) encode_i = self.occlude_input(encode_i, occlusion_map[occlusion_ind]) - warped_encoder_maps.append(encode_ij) + warped_encoder_maps.append(encode_i.detach()) if(i==self.num_down_blocks-1): break - out = torch.cat([out, encode_i], 1) + if self.concat_encode: + out = torch.cat([out, encode_i], 1) + else: + out = out + encode_i deformed_source = self.deform_input(source_image, deformation) output_dict["deformed"] = deformed_source diff --git a/modules/model.py b/modules/model.py index d8a449a..683b083 100644 --- a/modules/model.py +++ b/modules/model.py @@ -83,6 +83,7 @@ class GeneratorFullModel(torch.nn.Module): """ def __init__(self, kp_extractor, bg_predictor, dense_motion_network, inpainting_network, train_params, + discriminator=None, *kwargs): super(GeneratorFullModel, self).__init__() self.kp_extractor = kp_extractor @@ -115,24 +116,28 @@ def __init__(self, kp_extractor, bg_predictor, dense_motion_network, inpainting_ if self.loss_weights.get('id', 0) > 0: self.id_recognition_model = InceptionResnetV1(pretrained='vggface2').eval() self.id_recognition_model.requires_grad_(False) - self.mtcnn = MTCNN() - self.mtcnn.requires_grad_(False) + #self.mtcnn = MTCNN() + #self.mtcnn.requires_grad_(False) if torch.cuda.is_available(): self.id_recognition_model = self.id_recognition_model.cuda() - self.mtcnn = self.mtcnn.cuda() + #self.mtcnn = self.mtcnn.cuda() else: self.id_recognition_model = None + if self.loss_weights.get('gan', 0) > 0: + self.discriminator = discriminator + if torch.cuda.is_available(): + self.discriminator = self.discriminator.cuda() + - def forward(self, x, epoch): + def forward(self, x, epoch, step=None): kp_source = self.kp_extractor(x['source']) kp_driving = self.kp_extractor(x['driving']) bg_param = None - if self.bg_predictor: - if(epoch>=self.bg_start): - bg_param = self.bg_predictor(x['source'], x['driving']) + if self.bg_predictor and epoch>=self.bg_start: + bg_param = self.bg_predictor(x['source'], x['driving']) if(epoch>=self.dropout_epoch): dropout_flag = False @@ -213,28 +218,43 @@ def forward(self, x, epoch): # id loss if self.id_recognition_model and self.loss_weights['id'] != 0: - try: - driving_preprocessed = x['driving'] * 255 - driving_preprocessed = driving_preprocessed.permute(0, 2, 3, 1) - driving_preprocessed = driving_preprocessed.to(torch.float16) - driving_preprocessed = self.mtcnn(driving_preprocessed) - generated_preprocessed = generated['prediction'] * 255 - generated_preprocessed = generated_preprocessed.permute(0, 2, 3, 1) - generated_preprocessed = generated_preprocessed.to(torch.float16) - - generated_preprocessed = self.mtcnn(generated_preprocessed) - except Exception as e: - print('MTCNN failed, using bilinear interpolation') - print(e) - driving_preprocessed = prewhiten(x['driving']) - driving_preprocessed = torch.nn.functional.interpolate(driving_preprocessed, size=(160, 160), mode='bilinear', align_corners=True) - generated_preprocessed = prewhiten(generated['prediction']) - generated_preprocessed = torch.nn.functional.interpolate(generated_preprocessed, size=(160, 160), mode='bilinear', align_corners=True) + #try: + # driving_preprocessed = x['driving'] * 255 + # driving_preprocessed = driving_preprocessed.permute(0, 2, 3, 1) + # driving_preprocessed = 
driving_preprocessed.to(torch.float16) + # driving_preprocessed = self.mtcnn(driving_preprocessed) + # generated_preprocessed = generated['prediction'] * 255 + # generated_preprocessed = generated_preprocessed.permute(0, 2, 3, 1) + # generated_preprocessed = generated_preprocessed.to(torch.float16) + + # generated_preprocessed = self.mtcnn(generated_preprocessed) + #except Exception as e: + #print('MTCNN failed, using bilinear interpolation') + #print(e) + driving_preprocessed = prewhiten(x['driving']) + driving_preprocessed = torch.nn.functional.interpolate(driving_preprocessed, size=(160, 160), mode='bilinear', align_corners=True) + generated_preprocessed = prewhiten(generated['prediction']) + generated_preprocessed = torch.nn.functional.interpolate(generated_preprocessed, size=(160, 160), mode='bilinear', align_corners=True) id_real = self.id_recognition_model(driving_preprocessed) id_generated = self.id_recognition_model(generated_preprocessed) #cosine value = 1 - torch.nn.functional.cosine_similarity(id_real, id_generated, dim=1) loss_values['id'] = self.loss_weights['id'] * value.mean() + # gan loss + #if self.loss_weights['discriminator_gan'] != 0: + # if step % self.discriminator_steps == 0: + # discriminator_maps_generated = self.discriminator(generated['prediction'], x['driving']) + # discriminator_maps_real = self.discriminator(x['driving'], x['driving']) + # discriminator_loss = 0 + # for scale in discriminator_maps_generated: + # value_generated = discriminator_maps_generated[scale] + # value_real = discriminator_maps_real[scale] + # discriminator_loss += self.loss_weights['discriminator_gan'] * ( + # self.gan_loss(value_real, False, True) + self.gan_loss(value_generated, True, False)) / len( + # discriminator_maps_generated) + # loss_values['discriminator'] = discriminator_loss + + return loss_values, generated diff --git a/run.py b/run.py index ddb5c17..0773528 100644 --- a/run.py +++ b/run.py @@ -40,9 +40,12 @@ parser.add_argument("--checkpoint", default=None, help="path to checkpoint to restore") parser.add_argument("--kp_detector", default=None, help="path to kp_detector checkpoint to restore") parser.add_argument("--bg_predictor", default=None, help="path to bg_predictor checkpoint to restore") + parser.add_argument("--dense_motion_network", default=None, help="path to dense_motion_network checkpoint to restore") + parser.add_argument("--inpainting", default=None, help="path to inpainting checkpoint to restore") parser.add_argument("--detect_anomaly", action="store_true", help="detect anomaly in autograd") + opt = parser.parse_args() with open(opt.config) as f: config = yaml.load(f) @@ -100,7 +103,10 @@ log_dir, dataset, optimizer_class=optimizer_class, kp_detector_checkpoint=opt.kp_detector, - bg_predictor_checkpoint=opt.bg_predictor) + bg_predictor_checkpoint=opt.bg_predictor, + dense_motion_network_checkpoint=opt.dense_motion_network, + inpainting_checkpoint=opt.inpainting + ) elif opt.mode == 'train_avd': print("Training Animation via Disentaglement...") train_avd(config, inpainting, kp_detector, bg_predictor, dense_motion_network, avd_network, opt.checkpoint, diff --git a/save_model_only.py b/save_model_only.py index dbf2f23..faedc21 100644 --- a/save_model_only.py +++ b/save_model_only.py @@ -3,6 +3,7 @@ import yaml +from gan import MultiScaleDiscriminator from modules.inpainting_network import InpaintingNetwork from modules.keypoint_detector import KPDetector from modules.bg_motion_predictor import BGMotionPredictor @@ -14,8 +15,9 @@ if __name__ == '__main__': 
parser = ArgumentParser() parser.add_argument('config', help='path to config') - parser.add_argument('--checkpoint', '-c', help='path to checkpoint to restore') + parser.add_argument('checkpoint', help='path to checkpoint to restore') parser.add_argument('--target-checkpoint', '-t', default=None, help='path to checkpoint to save') + parser.add_argument('--discriminator', '-d', action='store_true', help='save discriminator') opt = parser.parse_args() @@ -25,35 +27,43 @@ checkpoint = torch.load(opt.checkpoint) print(checkpoint.keys()) - inpainting = InpaintingNetwork(**config['model_params']['generator_params'], - **config['model_params']['common_params']) - + **config['model_params']['common_params']) kp_detector = KPDetector(**config['model_params']['common_params']) dense_motion_network = DenseMotionNetwork(**config['model_params']['common_params'], **config['model_params']['dense_motion_params']) - bg_predictor = None if 'bg_predictor' in checkpoint: bg_predictor = BGMotionPredictor() + else: + print("No bg_predictor in checkpoint") avd_network = None if 'avd_network' in checkpoint: avd_network = AVDNetwork(num_tps=config['model_params']['common_params']['num_tps'], - **config['model_params']['avd_network_params']) + **config['model_params']['avd_network_params']) - target_checkpoint = opt.target_checkpoint if opt.target_checkpoint is not None else os.path.join(opt.checkpoint.split('.')[0] + '_model_only.pth.tar') + if opt.discriminator and 'discriminator' in checkpoint: + discriminator = MultiScaleDiscriminator(scales=[1], d=64) + target_checkpoint = opt.target_checkpoint if opt.target_checkpoint is not None else os.path.join( + opt.checkpoint.split('.')[0] + '_model_only.pth.tar') inpainting.load_state_dict(checkpoint['inpainting_network']) kp_detector.load_state_dict(checkpoint['kp_detector']) dense_motion_network.load_state_dict(checkpoint['dense_motion_network']) + if bg_predictor is not None: + print("Loading bg_predictor") bg_predictor.load_state_dict(checkpoint['bg_predictor']) if avd_network is not None: + print("Loading avd_network") avd_network.load_state_dict(checkpoint['avd_network']) + if opt.discriminator and 'discriminator' in checkpoint: + print("Loading discriminator") + discriminator.load_state_dict(checkpoint['discriminator']) save_dict = { 'inpainting_network': inpainting.state_dict(), @@ -64,5 +74,9 @@ save_dict['bg_predictor'] = bg_predictor.state_dict() if avd_network is not None: save_dict['avd_network'] = avd_network.state_dict() + if opt.discriminator and 'discriminator' in checkpoint: + save_dict['discriminator'] = discriminator.state_dict() + + print(f"Saving keys: {save_dict.keys()}") - torch.save(save_dict, target_checkpoint) \ No newline at end of file + torch.save(save_dict, target_checkpoint) diff --git a/train.py b/train.py index d0d2aa6..fb4bfea 100644 --- a/train.py +++ b/train.py @@ -3,7 +3,8 @@ import torch from torch.utils.data import DataLoader -from gan import MultiScaleDiscriminator, discriminator_adversarial_loss, generator_adversarial_loss +from gan import MultiScaleDiscriminator, discriminator_adversarial_loss, generator_adversarial_loss, \ + weak_feature_matching_loss from logger import Logger from modules.model import GeneratorFullModel from torch.optim.lr_scheduler import OneCycleLR @@ -14,6 +15,8 @@ from accelerate import Accelerator from torchview import draw_graph +from utils import load_params + torch.backends.cudnn.benchmark = True accelerator = Accelerator() @@ -24,8 +27,11 @@ def train(config, inpainting_network, kp_detector, 
bg_predictor, dense_motion_ne optimizer_class=torch.optim.Adam, kp_detector_checkpoint=None, bg_predictor_checkpoint=None, + dense_motion_network_checkpoint=None, + inpainting_checkpoint=None, ): train_params = config['train_params'] + scheduler_params = config['train_params'].get('scheduler_params', {}) optimizer_params = config['train_params'].get('optimizer_params', {}) optimizer = optimizer_class( @@ -35,10 +41,14 @@ def train(config, inpainting_network, kp_detector, bg_predictor, dense_motion_ne 'initial_lr': train_params['lr_generator']}], lr=train_params['lr_generator'], **optimizer_params) - discriminator = MultiScaleDiscriminator(scales=[1], d=64) - optimizer_discriminator = optimizer_class( - [{'params': list(discriminator.parameters()), 'initial_lr': train_params['lr_discriminator']}], - lr=train_params['lr_discriminator'], **optimizer_params) + if train_params['loss_weights'].get('discriminator_gan', 0) > 0: + discriminator = MultiScaleDiscriminator(scales=[1], d=64) + optimizer_discriminator = optimizer_class( + [{'params': list(discriminator.parameters()), 'initial_lr': train_params['lr_discriminator']}], + lr=train_params['lr_discriminator'], **optimizer_params) + else: + discriminator = None + optimizer_discriminator = None @@ -62,16 +72,23 @@ def train(config, inpainting_network, kp_detector, bg_predictor, dense_motion_ne start_epoch = 0 if kp_detector_checkpoint is not None: - kp_params = torch.load(kp_detector_checkpoint) - kp_detector.load_state_dict(kp_params['kp_detector']) - print('load kp detector success') + load_params(kp_detector, kp_detector_checkpoint, name='kp_detector', + strict=True) if bg_predictor_checkpoint is not None: - bg_params = torch.load(bg_predictor_checkpoint) - bg_predictor.load_state_dict(bg_params['bg_predictor']) - print('load bg predictor success') + load_params(bg_predictor, bg_predictor_checkpoint, name='bg_predictor', + strict=True) + if inpainting_checkpoint is not None: + load_params(inpainting_network, inpainting_checkpoint, name='inpainting_network', + strict=False, find_alternative_weights=True) + + print('load inpainting network success') + if dense_motion_network_checkpoint is not None: + load_params(dense_motion_network, dense_motion_network_checkpoint, name='dense_motion_network', + strict=False, find_alternative_weights=True) freeze_kp_detector = train_params.get('freeze_kp_detector', False) freeze_bg_predictor = train_params.get('freeze_bg_predictor', False) + freeze_dense_motion_network = train_params.get('freeze_dense_motion_network', False) if freeze_kp_detector: print('freeze kp detector') kp_detector.eval() @@ -82,6 +99,11 @@ def train(config, inpainting_network, kp_detector, bg_predictor, dense_motion_ne bg_predictor.eval() for param in bg_predictor.parameters(): param.requires_grad = False + if freeze_dense_motion_network: + print('freeze dense motion network') + dense_motion_network.eval() + for param in dense_motion_network.parameters(): + param.requires_grad = False if 'num_repeats' in train_params or train_params['num_repeats'] != 1: dataset = DatasetRepeater(dataset, train_params['num_repeats']) @@ -94,22 +116,26 @@ def train(config, inpainting_network, kp_detector, bg_predictor, dense_motion_ne scheduler_optimizer = OneCycleLR(optimizer, max_lr=train_params['lr_generator'], total_steps=(len(dataset) // train_params['batch_size']) * train_params[ 'num_epochs'], - last_epoch=last_epoch + last_epoch=last_epoch, + **scheduler_params ) - discriminator_scheduler = OneCycleLR(optimizer_discriminator, 
max_lr=train_params['lr_discriminator'], - total_steps=(len(dataset) // train_params['batch_size']) * train_params[ - 'num_epochs'], - last_epoch=last_epoch - ) + if discriminator: + discriminator_scheduler = OneCycleLR(optimizer_discriminator, max_lr=train_params['lr_discriminator'], + total_steps=(len(dataset) // train_params['batch_size']) * train_params[ + 'num_epochs'], + last_epoch=last_epoch, + **scheduler_params + ) - discriminator, optimizer_discriminator, discriminator_scheduler = accelerator.prepare(discriminator, optimizer_discriminator, discriminator_scheduler) + discriminator, optimizer_discriminator, discriminator_scheduler = accelerator.prepare(discriminator, optimizer_discriminator, discriminator_scheduler) scheduler_bg_predictor = None if bg_predictor: scheduler_bg_predictor = OneCycleLR(optimizer_bg_predictor, max_lr=train_params['lr_generator'], total_steps=(len(dataset) // train_params['batch_size']) * train_params[ 'num_epochs'], - last_epoch=last_epoch + last_epoch=last_epoch, + **scheduler_params ) bg_predictor, optimizer_bg_predictor = accelerator.prepare(bg_predictor, optimizer_bg_predictor) @@ -149,6 +175,7 @@ def train(config, inpainting_network, kp_detector, bg_predictor, dense_motion_ne with Logger(log_dir=log_dir, visualizer_params=config['visualizer_params'], checkpoint_freq=train_params['checkpoint_freq'], models=model_list, + train_config=config, ) as logger: for epoch in trange(start_epoch, train_params['num_epochs']): i = 0 @@ -159,26 +186,33 @@ def train(config, inpainting_network, kp_detector, bg_predictor, dense_motion_ne losses_generator, generated = generator_full(x, epoch) disc_loss = torch.zeros(1, device=x['driving'].device) gen_loss = torch.zeros(1, device=x['driving'].device) - - if i % 2 == 0: - disc_pred_fake = discriminator(generated['prediction']) - disc_pred_real = discriminator(x['driving']) - for j in range(len(disc_pred_real)): # number of scales - disc_loss += discriminator_adversarial_loss(disc_pred_real[j], disc_pred_fake[j]) - else: - features_fake, fake_preds = discriminator.forward_with_features(generated['prediction']) - features_real, _ = discriminator.forward_with_features(x['driving']) - for k in range(len(fake_preds)): - gen_loss += generator_adversarial_loss(fake_preds[k]) - - losses_generator['gen'] = gen_loss + feat_match_loss = torch.zeros(1, device=x['driving'].device) + + # discriminator / generator adversarial loss + if discriminator: + if i % 2 == 0: + disc_pred_fake = discriminator(generated['prediction']) + disc_pred_real = discriminator(x['driving']) + for j in range(len(disc_pred_real)): # number of scales + disc_loss += discriminator_adversarial_loss(disc_pred_real[j], disc_pred_fake[j]) + disc_loss *= train_params['loss_weights']['discriminator_gan'] + else: + features_fake, fake_preds = discriminator.forward_with_features(generated['prediction']) + features_real, _ = discriminator.forward_with_features(x['driving']) + for k in range(len(fake_preds)): + gen_loss += generator_adversarial_loss(fake_preds[k]) + feat_match_loss += weak_feature_matching_loss(features_fake[k], features_real[k], + start_layer=0) + + losses_generator['gen'] = gen_loss * train_params['loss_weights']['generator_gan'] + losses_generator['feat_match'] = feat_match_loss * train_params['loss_weights']['generator_feat_match'] loss_values = [val.mean() for val in losses_generator.values()] loss = sum(loss_values) - - if i % 2 == 0: + # discriminator step + if i % 2 == 0 and discriminator: accelerator.backward(disc_loss, retain_graph=True) 
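# The adversarial terms alternate by batch index: even batches back-propagate disc_loss
# through the discriminator only, odd batches add the 'gen' (generator_gan) and
# 'feat_match' (generator_feat_match) terms to the main generator loss. retain_graph=True
# appears to be needed because the prediction is not detached for the discriminator pass,
# so the generator backward below traverses the same graph.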
clip_grad_norm_(discriminator.parameters(), max_norm=10, norm_type=math.inf) @@ -203,13 +237,14 @@ def train(config, inpainting_network, kp_detector, bg_predictor, dense_motion_ne scheduler_optimizer.step() optimizer.zero_grad() - discriminator_scheduler.step() + if discriminator: + discriminator_scheduler.step() losses = {key: value.mean().detach().data.cpu().numpy() for key, value in losses_generator.items()} lrs = { 'lr_generator': scheduler_optimizer.get_last_lr()[0], 'lr_bg_predictor': scheduler_bg_predictor.get_last_lr()[0] if bg_predictor else 0, - 'lr_discriminator': discriminator_scheduler.get_last_lr()[0] + 'lr_discriminator': discriminator_scheduler.get_last_lr()[0] if discriminator else 0, } losses['disc'] = disc_loss.mean().detach().data.cpu().numpy() logger.log_iter(losses=losses, others=lrs) @@ -223,7 +258,7 @@ def train(config, inpainting_network, kp_detector, bg_predictor, dense_motion_ne 'optimizer': optimizer, 'bg_predictor': accelerator.unwrap_model(bg_predictor) if bg_predictor else None, 'optimizer_bg_predictor': optimizer_bg_predictor, - 'discriminator': accelerator.unwrap_model(discriminator), + 'discriminator': accelerator.unwrap_model(discriminator) if discriminator else None, 'optimizer_discriminator': optimizer_discriminator, } diff --git a/utils.py b/utils.py index c476d00..d87ad23 100644 --- a/utils.py +++ b/utils.py @@ -4,6 +4,8 @@ import imageio import logging +import torch + logger = logging.getLogger("TPSMM") IMAGE_FORMATS = ["png", "jpg", "jpeg", "bmp", "tif", "tiff"] @@ -134,3 +136,52 @@ def close(self): return None +def load_params(model, path, map_location=None, name="model", + strict=True, find_alternative_weights=False): + if path is None: + return + + if os.path.isdir(path): + path = os.path.join(path, "model.pt") + + if not os.path.exists(path): + logger.warning(f"Could not find checkpoint at {path}") + return + + logger.info(f"Loading checkpoint from {path}") + + checkpoint = torch.load(path, map_location=map_location) + try: + model.load_state_dict(checkpoint[name], strict=strict) + except Exception as e: + logger.error(f"Could not load model from {path}: {e}") + + if strict: + raise e + + + without_match = set() + for k, v in model.items(): + if k in model.state_dict(): + if v.shape == model.state_dict()[k].shape: + model.state_dict()[k].copy_(v) + else: + without_match.add(k) + logger.warning(f'Could not find direct match for {k} in checkpoint') + else: + without_match.add(k) + logger.warning(f'Could not find direct match for {k} in checkpoint') + + if find_alternative_weights: + for k in without_match: + logger.info(f"Trying to find alternative weights for {k}") + for ckpt_k, ckpt_v in checkpoint[name].items(): + if ckpt_v.shape == model.state_dict()[k].shape: + logger.info(f"Found alternative weights for {k} in {ckpt_k}") + model.state_dict()[k].copy_(ckpt_v) + break + else: + logger.warning(f"Could not find alternative weights for {k}") + + + From 26635bf004cf8eba09c4a6b3e7c48cd3c211900e Mon Sep 17 00:00:00 2001 From: TGG Date: Thu, 20 Jul 2023 10:44:32 +0200 Subject: [PATCH 11/30] change pyyaml version --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 9e83ed8..f2ce41f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,7 +15,7 @@ pyparsing==2.4.7 python-dateutil==2.8.2 pytz==2021.1 PyWavelets -PyYAML==5.4.1 +PyYAML>=6.0.1 scikit-image scikit-learn scipy From 1cc1644baab0f60b6879040ceedbd63d0213cb5a Mon Sep 17 00:00:00 2001 From: TGG Date: Fri, 8 Sep 2023 
13:47:38 +0200 Subject: [PATCH 12/30] update --- config/vox-256-deeper-other.yaml | 98 ++++++++++++++++++++++++++++++++ config/vox-256-deeper.yaml | 7 ++- config/vox-512-deeper.yaml | 2 +- config/vox-512-finetune.yaml | 52 ++++++++++------- config/vox-768-finetune.yaml | 21 +++++-- logger.py | 15 +++-- gan.py => modules/gan.py | 0 modules/inpainting_network.py | 62 ++++++++++++++++++-- train.py | 31 +++++----- utils.py | 41 +++++++++---- video-preprocessing | 2 +- 11 files changed, 264 insertions(+), 67 deletions(-) create mode 100644 config/vox-256-deeper-other.yaml rename gan.py => modules/gan.py (100%) diff --git a/config/vox-256-deeper-other.yaml b/config/vox-256-deeper-other.yaml new file mode 100644 index 0000000..e25949e --- /dev/null +++ b/config/vox-256-deeper-other.yaml @@ -0,0 +1,98 @@ +name: vox-256-deeper-other + +dataset_params: + root_dir: ../data/vox512_webp + frame_shape: 256,256,3 + id_sampling: True + augmentation_params: + flip_param: + horizontal_flip: True + time_flip: True + jitter_param: + brightness: 0.1 + contrast: 0.1 + saturation: 0.1 + hue: 0.1 + + +model_params: + common_params: + num_tps: 10 + num_channels: 3 + bg: True + multi_mask: True + generator_params: + block_expansion: 64 + max_features: 512 + num_down_blocks: 4 + concat_encode: True + skip_block_type: depthwise + dropout: 0.1 + dense_motion_params: + block_expansion: 64 + max_features: 1024 + num_blocks: 5 + scale_factor: 0.25 # might make sense to set to 0.5 because of the additional occlusion (4=>5) + occlusion_num: 5 + + avd_network_params: + id_bottle_size: 128 + pose_bottle_size: 128 + + +train_params: + num_epochs: 100 + num_repeats: 5 + lr_generator: 2.0e-4 + lr_discriminator: 2.0e-4 + batch_size: 8 + scales: [1, 0.5, 0.25, 0.125] + dataloader_workers: 8 + checkpoint_freq: 10 + dropout_epoch: 30 + dropout_maxp: 0.3 + dropout_startp: 0.1 + dropout_inc_epoch: 10 + bg_start: 101 + freeze_kp_detector: True + freeze_bg_predictor: True + freeze_dense_motion: False + transform_params: + sigma_affine: 0.05 + sigma_tps: 0.005 + points_tps: 5 + loss_weights: + perceptual: [10, 10, 10, 10, 10] + equivariance_value: 10 + warp_loss: 10 + bg: 0 + l2: 0 + id: 0.1 + huber: 0 + generator_gan: 1 + generator_feat_match: 0 + discriminator_gan: 1 + optimizer: 'adamw' + optimizer_params: + betas: [ 0.9, 0.999 ] + weight_decay: 1.0e-3 + scheduler: 'onecycle' + scheduler_params: + pct_start: 0.3 + + +train_avd_params: + num_epochs: 100 + num_repeats: 1 + batch_size: 8 + dataloader_workers: 6 + checkpoint_freq: 1 + epoch_milestones: [10, 20] + lr: 1.0e-3 + lambda_shift: 1 + random_scale: 0.25 + +visualizer_params: + kp_size: 5 + draw_border: True + colormap: 'gist_rainbow' \ No newline at end of file diff --git a/config/vox-256-deeper.yaml b/config/vox-256-deeper.yaml index 047a65c..3216459 100644 --- a/config/vox-256-deeper.yaml +++ b/config/vox-256-deeper.yaml @@ -1,3 +1,5 @@ +name: vox-256-deeper + dataset_params: root_dir: ./vox512_webp frame_shape: 256,256,3 @@ -23,9 +25,12 @@ model_params: block_expansion: 64 max_features: 512 num_down_blocks: 4 + concat_encode: False + use_skip_blocks: False + dropout: 0.0 dense_motion_params: block_expansion: 64 - max_features: 512 + max_features: 1024 num_blocks: 5 scale_factor: 0.25 # might make sense to set to 0.5 because of the additional occlusion (4=>5) occlusion_num: 5 diff --git a/config/vox-512-deeper.yaml b/config/vox-512-deeper.yaml index aa791a8..f3de641 100644 --- a/config/vox-512-deeper.yaml +++ b/config/vox-512-deeper.yaml @@ -24,7 +24,7 @@ model_params: 
max_features: 512 num_down_blocks: 4 concat_encode: False - use_skip_blocks: True + use_skip_blocks: False dense_motion_params: block_expansion: 64 max_features: 1024 diff --git a/config/vox-512-finetune.yaml b/config/vox-512-finetune.yaml index 4574848..cd45237 100644 --- a/config/vox-512-finetune.yaml +++ b/config/vox-512-finetune.yaml @@ -1,7 +1,6 @@ -# Use this file to finetune from a pretrained 256x256 model dataset_params: root_dir: ./video-preprocessing/vox2-768 - frame_shape: 512,512,3 + frame_shape: 256,256,3 id_sampling: True augmentation_params: flip_param: @@ -24,6 +23,7 @@ model_params: block_expansion: 64 max_features: 512 num_down_blocks: 3 + dropout: 0.1 dense_motion_params: block_expansion: 64 max_features: 1024 @@ -35,22 +35,21 @@ model_params: train_params: - num_epochs: 40 - num_repeats: 4 - # Higher LR seems to bring problems when finetuning - lr_generator: 2.0e-4 - lr_discriminator: 2.0e-3 - batch_size: 2 - scales: [1, 0.5, 0.25, 0.125, 0.0625] - dataloader_workers: 6 - checkpoint_freq: 5 - dropout_epoch: 2 + num_epochs: 50 + num_repeats: 10 + lr_generator: 2.0e-5 + lr_discriminator: 2.0e-5 + batch_size: 4 + scales: [1, 0.5, 0.25, 0.125] + dataloader_workers: 12 + checkpoint_freq: 10 + dropout_epoch: 0 dropout_maxp: 0.3 dropout_startp: 0.1 - dropout_inc_epoch: 1 - bg_start: 41 - freeze_kp_detector: True - freeze_bg_predictor: True + dropout_inc_epoch: 10 + bg_start: 6 + freeze_kp_detector: False + freeze_bg_predictor: False transform_params: sigma_affine: 0.05 sigma_tps: 0.005 @@ -59,19 +58,28 @@ train_params: perceptual: [10, 10, 10, 10, 10] equivariance_value: 10 warp_loss: 10 - bg: 10 + bg: 0 l2: 0 + id: 1 + huber: 0 + generator_gan: 10 + generator_feat_match: 1000 + discriminator_gan: 10 + optimizer: 'adamw' optimizer_params: - betas: [0.9, 0.999] - weight_decay: 0.1 + betas: [ 0.9, 0.999 ] + weight_decay: 1.0e-3 + scheduler: 'onecycle' + scheduler_params: + pct_start: 0.01 train_avd_params: - num_epochs: 200 + num_epochs: 100 num_repeats: 1 - batch_size: 4 + batch_size: 8 dataloader_workers: 6 - checkpoint_freq: 10 + checkpoint_freq: 1 epoch_milestones: [10, 20] lr: 1.0e-3 lambda_shift: 1 diff --git a/config/vox-768-finetune.yaml b/config/vox-768-finetune.yaml index 6ac37bf..e743baa 100644 --- a/config/vox-768-finetune.yaml +++ b/config/vox-768-finetune.yaml @@ -1,4 +1,5 @@ # Use this file to finetune from a pretrained 256x256 model +name: vox-768-finetune dataset_params: root_dir: ./video-preprocessing/vox2-768 frame_shape: 768,768,3 @@ -36,14 +37,15 @@ model_params: train_params: visualize_model: False - num_epochs: 80 - num_repeats: 10 + num_epochs: 50 + num_repeats: 1 # Higher LR seems to bring problems when finetuning - lr_generator: 3.0e-5 - batch_size: 2 + lr_generator: 2.0e-6 + lr_discriminator: 2.0e-5 + batch_size: 1 scales: [1, 0.5, 0.25, 0.125, 0.0625] dataloader_workers: 8 - checkpoint_freq: 2 + checkpoint_freq: 1 dropout_epoch: 0 dropout_maxp: 0.3 dropout_startp: 0.1 @@ -60,10 +62,19 @@ train_params: equivariance_value: 10 warp_loss: 10 bg: 10 + id: 0.1 + l2: 0 + huber: 0 + generator_gan: 10 + generator_feat_match: 10 + discriminator_gan: 10 optimizer: 'adamw' optimizer_params: betas: [ 0.9, 0.999 ] weight_decay: 0.1 + scheduler: 'onecycle' + scheduler_params: + pct_start: 0.01 train_avd_params: num_epochs: 200 diff --git a/logger.py b/logger.py index ea625d6..7b498f3 100644 --- a/logger.py +++ b/logger.py @@ -15,7 +15,8 @@ class Logger: def __init__(self, log_dir, checkpoint_freq=50, visualizer_params=None, - zfill_num=8, log_file_name='log.txt', 
models=()): + zfill_num=8, log_file_name='log.txt', models=(), + train_config=None): self.models = models self.loss_list = [] @@ -31,9 +32,11 @@ def __init__(self, log_dir, checkpoint_freq=50, visualizer_params=None, self.epoch = 0 self.best_loss = float('inf') self.names = None - wandb.init(project="TPSMM", dir=log_dir) + wandb.init(project="TPSMM", dir=log_dir, config=train_config, + name=train_config['name']) for model in models: - wandb.watch(model) + if model is not None: + wandb.watch(model, log="all", log_freq=500) def log_scores(self, loss_names): loss_mean = np.array(self.loss_list).mean(axis=0) @@ -62,7 +65,7 @@ def save_cpk(self, emergent=False): @staticmethod def load_cpk(checkpoint_path, inpainting_network=None, dense_motion_network=None, kp_detector=None, bg_predictor=None, avd_network=None, optimizer=None, optimizer_bg_predictor=None, - optimizer_avd=None, discriminator=None, discriminator_optimizer=None): + optimizer_avd=None, discriminator=None, optimizer_discriminator=None): checkpoint = torch.load(checkpoint_path) if inpainting_network is not None: inpainting_network.load_state_dict(checkpoint['inpainting_network']) @@ -84,8 +87,8 @@ def load_cpk(checkpoint_path, inpainting_network=None, dense_motion_network=None optimizer_avd.load_state_dict(checkpoint['optimizer_avd']) if discriminator is not None and 'discriminator' in checkpoint: discriminator.load_state_dict(checkpoint['discriminator']) - if discriminator_optimizer is not None and 'optimizer_discriminator' in checkpoint: - discriminator_optimizer.load_state_dict(checkpoint['optimizer_discriminator']) + if optimizer_discriminator is not None and 'optimizer_discriminator' in checkpoint: + optimizer_discriminator.load_state_dict(checkpoint['optimizer_discriminator']) epoch = -1 if 'epoch' in checkpoint: diff --git a/gan.py b/modules/gan.py similarity index 100% rename from gan.py rename to modules/gan.py diff --git a/modules/inpainting_network.py b/modules/inpainting_network.py index 73b101f..f77e87a 100644 --- a/modules/inpainting_network.py +++ b/modules/inpainting_network.py @@ -4,20 +4,62 @@ from modules.util import ResBlock2d, SameBlock2d, UpBlock2d, DownBlock2d from modules.dense_motion import DenseMotionNetwork +import torch + + +def get_kernel(): + """ + See https://setosa.io/ev/image-kernels/ + """ + + k1 = torch.tensor([[0.0625, 0.125, 0.0625], + [0.125, 0.25, 0.125], + [0.0625, 0.125, 0.0625]], dtype=torch.float32) + + # Sharpening Spatial Kernel, used in paper + k2 = torch.tensor([[-1, -1, -1], + [-1, 8, -1], + [-1, -1, -1]], dtype=torch.float32) + + k3 = torch.tensor([[0, -1, 0], + [-1, 5, -1], + [0, -1, 0]], dtype=torch.float32) + + return k1, k2, k3 + + +def build_sharp_blocks(layer): + """ + Sharp Blocks + """ + # Get number of channels in the feature + out_channels = layer.shape[0] + in_channels = layer.shape[1] + # Get kernel + _, w, _ = get_kernel() + # Change dimension + w = torch.unsqueeze(w, dim=0) # add an out_channel dimension at the beginning + # Repeat filter by out_channels times to get (out_channels, H, W) + w = w.repeat(out_channels, 1, 1) + # Expand dimension + w = torch.unsqueeze(w, dim=1) # add an in_channel dimension after out_channels + return w + class InpaintingNetwork(nn.Module): """ Inpaint the missing regions and reconstruct the Driving image. 
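Skip connections can optionally be filtered by a fixed sharpening kernel or a learnable
depthwise convolution (skip_block_type), and 2D dropout can be applied between decoder
blocks (dropout).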
""" def __init__(self, num_channels, block_expansion, max_features, num_down_blocks, multi_mask = True, - concat_encode=True, use_skip_blocks=False, - **kwargs): + concat_encode=True, skip_block_type=None, + dropout=0.0, **kwargs): super(InpaintingNetwork, self).__init__() self.num_down_blocks = num_down_blocks self.multi_mask = multi_mask self.first = SameBlock2d(num_channels, block_expansion, kernel_size=(7, 7), padding=(3, 3)) self.concat_encode = concat_encode + self.dropout = dropout down_blocks = [] up_blocks = [] @@ -27,8 +69,15 @@ def __init__(self, num_channels, block_expansion, max_features, num_down_blocks, in_features = min(max_features, block_expansion * (2 ** i)) out_features = min(max_features, block_expansion * (2 ** (i + 1))) down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1))) - if use_skip_blocks: - skip_blocks.append(nn.Conv2d(in_features, out_features, kernel_size=(1, 1))) + if skip_block_type == 'sharp': + # depthwise conv + skip_blocks.append(nn.Conv2d(out_features, out_features, kernel_size=(3, 3), padding=(1, 1), groups=out_features, + bias=False)) + weight = build_sharp_blocks(skip_blocks[-1].weight) + skip_blocks[-1].weight = nn.Parameter(weight, requires_grad=False) + elif skip_block_type == 'depthwise': + skip_blocks.append(nn.Conv2d(out_features, out_features, kernel_size=(3, 3), padding=(1, 1), groups=out_features, + bias=False)) if concat_encode: decoder_in_feature = out_features * 2 else: @@ -43,7 +92,7 @@ def __init__(self, num_channels, block_expansion, max_features, num_down_blocks, self.up_blocks = nn.ModuleList(up_blocks[::-1]) self.resblock = nn.ModuleList(resblock[::-1]) if skip_blocks: - self.skip_blocks = nn.ModuleList(skip_blocks[::-1]) + self.skip_blocks = nn.ModuleList(skip_blocks) else: self.skip_blocks = None @@ -91,6 +140,9 @@ def forward(self, source_image, dense_motion): for i in range(self.num_down_blocks): + if self.dropout > 0: + out = F.dropout2d(out, p=self.dropout, training=self.training) + out = self.resblock[2*i](out) # e.g. 0, 2, 4, 6 out = self.resblock[2*i+1](out) # e.g. 1, 3, 5, 7 out = self.up_blocks[i](out) # e.g. 
0, 1, 2, 3 diff --git a/train.py b/train.py index fb4bfea..4af0b8e 100644 --- a/train.py +++ b/train.py @@ -3,7 +3,7 @@ import torch from torch.utils.data import DataLoader -from gan import MultiScaleDiscriminator, discriminator_adversarial_loss, generator_adversarial_loss, \ +from modules.gan import MultiScaleDiscriminator, discriminator_adversarial_loss, generator_adversarial_loss, \ weak_feature_matching_loss from logger import Logger from modules.model import GeneratorFullModel @@ -28,7 +28,7 @@ def train(config, inpainting_network, kp_detector, bg_predictor, dense_motion_ne kp_detector_checkpoint=None, bg_predictor_checkpoint=None, dense_motion_network_checkpoint=None, - inpainting_checkpoint=None, + inpainting_checkpoint=None, ): train_params = config['train_params'] scheduler_params = config['train_params'].get('scheduler_params', {}) @@ -45,14 +45,13 @@ def train(config, inpainting_network, kp_detector, bg_predictor, dense_motion_ne discriminator = MultiScaleDiscriminator(scales=[1], d=64) optimizer_discriminator = optimizer_class( [{'params': list(discriminator.parameters()), 'initial_lr': train_params['lr_discriminator']}], - lr=train_params['lr_discriminator'], **optimizer_params) + lr=train_params['lr_discriminator'], **optimizer_params) else: discriminator = None optimizer_discriminator = None - - - torchinfo.summary(discriminator, input_size=(1, 3, 256, 256)) + if discriminator is not None: + torchinfo.summary(discriminator, input_size=(1, 3, 256, 256)) optimizer_bg_predictor = None if bg_predictor: @@ -121,13 +120,15 @@ def train(config, inpainting_network, kp_detector, bg_predictor, dense_motion_ne ) if discriminator: discriminator_scheduler = OneCycleLR(optimizer_discriminator, max_lr=train_params['lr_discriminator'], - total_steps=(len(dataset) // train_params['batch_size']) * train_params[ - 'num_epochs'], - last_epoch=last_epoch, - **scheduler_params - ) + total_steps=(len(dataset) // train_params['batch_size']) * train_params[ + 'num_epochs'], + last_epoch=last_epoch, + **scheduler_params + ) - discriminator, optimizer_discriminator, discriminator_scheduler = accelerator.prepare(discriminator, optimizer_discriminator, discriminator_scheduler) + discriminator, optimizer_discriminator, discriminator_scheduler = accelerator.prepare(discriminator, + optimizer_discriminator, + discriminator_scheduler) scheduler_bg_predictor = None if bg_predictor: @@ -144,7 +145,6 @@ def train(config, inpainting_network, kp_detector, bg_predictor, dense_motion_ne bg_start = train_params['bg_start'] - inpainting_network, kp_detector, dense_motion_network, optimizer, scheduler_optimizer, dataloader, generator_full = accelerator.prepare( inpainting_network, kp_detector, dense_motion_network, optimizer, scheduler_optimizer, dataloader, generator_full) @@ -205,12 +205,12 @@ def train(config, inpainting_network, kp_detector, bg_predictor, dense_motion_ne start_layer=0) losses_generator['gen'] = gen_loss * train_params['loss_weights']['generator_gan'] - losses_generator['feat_match'] = feat_match_loss * train_params['loss_weights']['generator_feat_match'] + losses_generator['feat_match'] = feat_match_loss * train_params['loss_weights'][ + 'generator_feat_match'] loss_values = [val.mean() for val in losses_generator.values()] loss = sum(loss_values) - # discriminator step if i % 2 == 0 and discriminator: accelerator.backward(disc_loss, retain_graph=True) @@ -228,7 +228,6 @@ def train(config, inpainting_network, kp_detector, bg_predictor, dense_motion_ne optimizer.step() - if bg_predictor and epoch 
>= bg_start and not freeze_bg_predictor: optimizer_bg_predictor.step() optimizer_bg_predictor.zero_grad() diff --git a/utils.py b/utils.py index d87ad23..c6295fe 100644 --- a/utils.py +++ b/utils.py @@ -136,6 +136,16 @@ def close(self): return None +def get_layer_type_from_key(key): + if "conv" in key: + return "conv" + elif "norm" in key: + return "norm" + elif "occlusion" in key: + return "occlusion" + else: + return None + def load_params(model, path, map_location=None, name="model", strict=True, find_alternative_weights=False): if path is None: @@ -161,27 +171,38 @@ def load_params(model, path, map_location=None, name="model", without_match = set() - for k, v in model.items(): - if k in model.state_dict(): - if v.shape == model.state_dict()[k].shape: - model.state_dict()[k].copy_(v) + for k, v in model.state_dict().items(): + if k in checkpoint[name]: + if v.shape == checkpoint[name][k].shape: + logger.info(f"Found direct match for {k}") + model.state_dict()[k].copy_(checkpoint[name][k]) else: without_match.add(k) logger.warning(f'Could not find direct match for {k} in checkpoint') else: without_match.add(k) - logger.warning(f'Could not find direct match for {k} in checkpoint') + logger.warning(f'Could not find {k} in checkpoint') + if find_alternative_weights: - for k in without_match: - logger.info(f"Trying to find alternative weights for {k}") + temp_without_match = set(without_match) + + logger.info(f"Trying to find alternative weights for {len(without_match)} keys") + for k in temp_without_match: + weights_type = k.split(".")[-1] + layer_type = get_layer_type_from_key(k) + for ckpt_k, ckpt_v in checkpoint[name].items(): + if layer_type is None or layer_type not in ckpt_k or weights_type not in ckpt_k: + continue + if ckpt_v.shape == model.state_dict()[k].shape: - logger.info(f"Found alternative weights for {k} in {ckpt_k}") + logger.warning(f"Found alternative weights for {k} in {ckpt_k}") model.state_dict()[k].copy_(ckpt_v) + without_match.remove(k) break - else: - logger.warning(f"Could not find alternative weights for {k}") + + logger.warning(f"Could not load {len(without_match)} keys from checkpoint") diff --git a/video-preprocessing b/video-preprocessing index ac40aac..c90f840 160000 --- a/video-preprocessing +++ b/video-preprocessing @@ -1 +1 @@ -Subproject commit ac40aac58657a3d8db85421cd4afcf465e86ead1 +Subproject commit c90f840e6c4e79b4c98656ed271c8e299665cb8a From 7a237dfc553d95989753c51ac73d8170f9e29b65 Mon Sep 17 00:00:00 2001 From: TGG Date: Fri, 8 Sep 2023 13:50:34 +0200 Subject: [PATCH 13/30] update reqs --- requirements.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index f2ce41f..0709032 100644 --- a/requirements.txt +++ b/requirements.txt @@ -23,4 +23,5 @@ six==1.16.0 torch==2.0.1 torchvision tqdm==4.62.3 -wandb \ No newline at end of file +wandb +accelerate From 9d5a45995c9c4698a25211235bc8427e7eb40dc6 Mon Sep 17 00:00:00 2001 From: TGG Date: Fri, 8 Sep 2023 13:51:55 +0200 Subject: [PATCH 14/30] remove albumentations --- frames_dataset.py | 1 - 1 file changed, 1 deletion(-) diff --git a/frames_dataset.py b/frames_dataset.py index edeaf71..91d83a1 100644 --- a/frames_dataset.py +++ b/frames_dataset.py @@ -1,6 +1,5 @@ import os -from albumentations import AdvancedBlur from skimage import io, img_as_float32 from skimage.color import gray2rgb from sklearn.model_selection import train_test_split From 8fbecab7ff78d8cf1208615a0021600ef7aad7fd Mon Sep 17 00:00:00 2001 From: TGG Date: Fri, 8 Sep 2023 13:52:32 
+0200 Subject: [PATCH 15/30] add torchinfo --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index 0709032..8056fe4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -25,3 +25,4 @@ torchvision tqdm==4.62.3 wandb accelerate +torchinfo From 50d4ecac1730b02bf812318bd23589aa5d947b8d Mon Sep 17 00:00:00 2001 From: TGG Date: Fri, 8 Sep 2023 13:59:33 +0200 Subject: [PATCH 16/30] pin all dependencies --- requirements.txt | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/requirements.txt b/requirements.txt index 8056fe4..27cd39e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,27 +2,27 @@ cffi==1.14.6 cycler==0.10.0 decorator==5.1.0 face-alignment==1.4.0 -imageio -imageio-ffmpeg +imageio==2.31.1 +imageio-ffmpeg==0.4.8 kiwisolver==1.3.2 matplotlib==3.4.3 networkx==2.6.3 -numpy +numpy==1.24.4 pandas==1.3.3 -Pillow +Pillow==9.5.0 pycparser==2.20 pyparsing==2.4.7 python-dateutil==2.8.2 pytz==2021.1 -PyWavelets +PyWavelets==1.1.1 PyYAML>=6.0.1 -scikit-image -scikit-learn -scipy +scikit-image==0.18.3 +scikit-learn==1.2.2 +scipy==1.11.1 six==1.16.0 torch==2.0.1 -torchvision +torchvision==0.15.1 tqdm==4.62.3 -wandb -accelerate -torchinfo +wandb==0.15.10 +accelerate==0.22.0 +torchinfo==1.8.0 From 87badb96d9613b98eda380ec8a068fb8a2590625 Mon Sep 17 00:00:00 2001 From: TGG Date: Fri, 8 Sep 2023 14:00:19 +0200 Subject: [PATCH 17/30] pin all dependencies --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 27cd39e..aee8a92 100644 --- a/requirements.txt +++ b/requirements.txt @@ -21,7 +21,7 @@ scikit-learn==1.2.2 scipy==1.11.1 six==1.16.0 torch==2.0.1 -torchvision==0.15.1 +torchvision==0.15.2 tqdm==4.62.3 wandb==0.15.10 accelerate==0.22.0 From ad87d38fe07ab29d381692d9953a892457c5182d Mon Sep 17 00:00:00 2001 From: TGG Date: Fri, 8 Sep 2023 14:01:40 +0200 Subject: [PATCH 18/30] add missing requirement --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index aee8a92..563b13f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -26,3 +26,4 @@ tqdm==4.62.3 wandb==0.15.10 accelerate==0.22.0 torchinfo==1.8.0 +facenet_pytorch==2.5.3 \ No newline at end of file From 1c208f076a8e2fffd1c3d72494a6ad9c66105881 Mon Sep 17 00:00:00 2001 From: TGG Date: Fri, 8 Sep 2023 14:02:21 +0200 Subject: [PATCH 19/30] add missing requirement --- requirements.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 563b13f..5f04516 100644 --- a/requirements.txt +++ b/requirements.txt @@ -26,4 +26,5 @@ tqdm==4.62.3 wandb==0.15.10 accelerate==0.22.0 torchinfo==1.8.0 -facenet_pytorch==2.5.3 \ No newline at end of file +facenet_pytorch==2.5.3 +torchview==0.2.6 From 2352e282026eef8f7386689822138e70cda7dccb Mon Sep 17 00:00:00 2001 From: TGG Date: Fri, 8 Sep 2023 14:03:06 +0200 Subject: [PATCH 20/30] add missing requirement --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index 5f04516..22cac10 100644 --- a/requirements.txt +++ b/requirements.txt @@ -28,3 +28,4 @@ accelerate==0.22.0 torchinfo==1.8.0 facenet_pytorch==2.5.3 torchview==0.2.6 +graphviz==0.20.1 From e989d16ebe4925e77c70e4cb7b918715c6d20ec7 Mon Sep 17 00:00:00 2001 From: TGG Date: Fri, 8 Sep 2023 14:03:40 +0200 Subject: [PATCH 21/30] add missing requirement --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git 
a/requirements.txt b/requirements.txt index 22cac10..039a3fe 100644 --- a/requirements.txt +++ b/requirements.txt @@ -29,3 +29,4 @@ torchinfo==1.8.0 facenet_pytorch==2.5.3 torchview==0.2.6 graphviz==0.20.1 +bitsandbytes==0.37.2 \ No newline at end of file From 09054ba0ac0b9249a2272a06ec06d769e2e496e6 Mon Sep 17 00:00:00 2001 From: TGG Date: Fri, 8 Sep 2023 14:07:25 +0200 Subject: [PATCH 22/30] remove bitsandbytes --- requirements.txt | 1 - run.py | 29 ++++++++++++----------------- 2 files changed, 12 insertions(+), 18 deletions(-) diff --git a/requirements.txt b/requirements.txt index 039a3fe..22cac10 100644 --- a/requirements.txt +++ b/requirements.txt @@ -29,4 +29,3 @@ torchinfo==1.8.0 facenet_pytorch==2.5.3 torchview==0.2.6 graphviz==0.20.1 -bitsandbytes==0.37.2 \ No newline at end of file diff --git a/run.py b/run.py index 0773528..5bc1aba 100644 --- a/run.py +++ b/run.py @@ -1,4 +1,5 @@ import matplotlib + matplotlib.use('Agg') import os, sys @@ -19,17 +20,14 @@ from reconstruction import reconstruction import os from torchinfo import summary -import bitsandbytes as bnb optimizer_choices = { 'adam': torch.optim.Adam, 'adamw': torch.optim.AdamW, - 'adam8bit': bnb.optim.Adam8bit, - "adamw8bit": bnb.optim.AdamW8bit, } if __name__ == "__main__": - + if sys.version_info[0] < 3: raise Exception("You must use Python 3 or higher. Recommended version is Python 3.9") @@ -40,15 +38,14 @@ parser.add_argument("--checkpoint", default=None, help="path to checkpoint to restore") parser.add_argument("--kp_detector", default=None, help="path to kp_detector checkpoint to restore") parser.add_argument("--bg_predictor", default=None, help="path to bg_predictor checkpoint to restore") - parser.add_argument("--dense_motion_network", default=None, help="path to dense_motion_network checkpoint to restore") + parser.add_argument("--dense_motion_network", default=None, + help="path to dense_motion_network checkpoint to restore") parser.add_argument("--inpainting", default=None, help="path to inpainting checkpoint to restore") parser.add_argument("--detect_anomaly", action="store_true", help="detect anomaly in autograd") - - opt = parser.parse_args() with open(opt.config) as f: - config = yaml.load(f) + config = yaml.load(f, Loader=yaml.FullLoader) if opt.checkpoint is not None: log_dir = os.path.join(*os.path.split(opt.checkpoint)[:-1]) @@ -60,14 +57,12 @@ torch.autograd.set_detect_anomaly(True) inpainting = InpaintingNetwork(**config['model_params']['generator_params'], - **config['model_params']['common_params']) - + **config['model_params']['common_params']) kp_detector = KPDetector(**config['model_params']['common_params']) dense_motion_network = DenseMotionNetwork(**config['model_params']['common_params'], **config['model_params']['dense_motion_params']) - bg_predictor = None if (config['model_params']['common_params']['bg']): bg_predictor = BGMotionPredictor() @@ -75,7 +70,7 @@ avd_network = None if opt.mode == "train_avd": avd_network = AVDNetwork(num_tps=config['model_params']['common_params']['num_tps'], - **config['model_params']['avd_network_params']) + **config['model_params']['avd_network_params']) dataset = FramesDataset(is_train=(opt.mode.startswith('train')), **config['dataset_params']) print("Dataset length: ", len(dataset)) @@ -104,15 +99,15 @@ optimizer_class=optimizer_class, kp_detector_checkpoint=opt.kp_detector, bg_predictor_checkpoint=opt.bg_predictor, - dense_motion_network_checkpoint=opt.dense_motion_network, + dense_motion_network_checkpoint=opt.dense_motion_network, 
inpainting_checkpoint=opt.inpainting - ) + ) elif opt.mode == 'train_avd': print("Training Animation via Disentaglement...") train_avd(config, inpainting, kp_detector, bg_predictor, dense_motion_network, avd_network, opt.checkpoint, log_dir, dataset, optimizer_class=optimizer_class) elif opt.mode == 'reconstruction': print("Reconstruction...") - #TODO: update to accelerate - reconstruction(config, inpainting, kp_detector, bg_predictor, dense_motion_network, opt.checkpoint, log_dir, dataset) - + # TODO: update to accelerate + reconstruction(config, inpainting, kp_detector, bg_predictor, dense_motion_network, opt.checkpoint, log_dir, + dataset) From 098138724c549de3bc2902078f334d634b5d7824 Mon Sep 17 00:00:00 2001 From: Philipp Haslbauer Date: Tue, 12 Sep 2023 20:41:43 +0200 Subject: [PATCH 23/30] add another config file --- config/vox-512-deeper-other.yaml | 98 ++++++++++++++++++++++++++++++++ 1 file changed, 98 insertions(+) create mode 100644 config/vox-512-deeper-other.yaml diff --git a/config/vox-512-deeper-other.yaml b/config/vox-512-deeper-other.yaml new file mode 100644 index 0000000..5cb7202 --- /dev/null +++ b/config/vox-512-deeper-other.yaml @@ -0,0 +1,98 @@ +name: vox-256-deeper-other + +dataset_params: + root_dir: ../data/vox512_filtered_webp + frame_shape: 512,512,3 + id_sampling: True + augmentation_params: + flip_param: + horizontal_flip: True + time_flip: True + jitter_param: + brightness: 0.1 + contrast: 0.1 + saturation: 0.1 + hue: 0.1 + + +model_params: + common_params: + num_tps: 10 + num_channels: 3 + bg: True + multi_mask: True + generator_params: + block_expansion: 64 + max_features: 512 + num_down_blocks: 4 + concat_encode: True + skip_block_type: depthwise + dropout: 0.1 + dense_motion_params: + block_expansion: 64 + max_features: 1024 + num_blocks: 5 + scale_factor: 0.25 # might make sense to set to 0.5 because of the additional occlusion (4=>5) + occlusion_num: 5 + + avd_network_params: + id_bottle_size: 128 + pose_bottle_size: 128 + + +train_params: + num_epochs: 30 + num_repeats: 5 + lr_generator: 2.0e-5 + lr_discriminator: 2.0e-5 + batch_size: 4 + scales: [1, 0.5, 0.25, 0.125] + dataloader_workers: 8 + checkpoint_freq: 10 + dropout_epoch: 0 + dropout_maxp: 0.3 + dropout_startp: 0.1 + dropout_inc_epoch: 10 + bg_start: 101 + freeze_kp_detector: True + freeze_bg_predictor: True + freeze_dense_motion: False + transform_params: + sigma_affine: 0.05 + sigma_tps: 0.005 + points_tps: 5 + loss_weights: + perceptual: [10, 10, 10, 10, 10] + equivariance_value: 10 + warp_loss: 10 + bg: 0 + l2: 0 + id: 0.1 + huber: 0 + generator_gan: 1 + generator_feat_match: 0 + discriminator_gan: 1 + optimizer: 'adamw' + optimizer_params: + betas: [ 0.9, 0.999 ] + weight_decay: 1.0e-3 + scheduler: 'onecycle' + scheduler_params: + pct_start: 0.3 + + +train_avd_params: + num_epochs: 100 + num_repeats: 1 + batch_size: 8 + dataloader_workers: 6 + checkpoint_freq: 1 + epoch_milestones: [10, 20] + lr: 1.0e-3 + lambda_shift: 1 + random_scale: 0.25 + +visualizer_params: + kp_size: 5 + draw_border: True + colormap: 'gist_rainbow' \ No newline at end of file From 39fd117adf3bc3c552699f2f241260d6bb210ae9 Mon Sep 17 00:00:00 2001 From: Philipp Haslbauer Date: Tue, 12 Sep 2023 20:44:06 +0200 Subject: [PATCH 24/30] update config file --- config/vox-512-deeper-other.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config/vox-512-deeper-other.yaml b/config/vox-512-deeper-other.yaml index 5cb7202..0857a2c 100644 --- a/config/vox-512-deeper-other.yaml +++ 
b/config/vox-512-deeper-other.yaml @@ -67,7 +67,7 @@ train_params: warp_loss: 10 bg: 0 l2: 0 - id: 0.1 + id: 1 huber: 0 generator_gan: 1 generator_feat_match: 0 From f63f938555ac73445b99d18cc182c77de9d45b89 Mon Sep 17 00:00:00 2001 From: Philipp Haslbauer Date: Wed, 13 Sep 2023 12:36:20 +0200 Subject: [PATCH 25/30] fix import --- save_model_only.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/save_model_only.py b/save_model_only.py index faedc21..cbec5f8 100644 --- a/save_model_only.py +++ b/save_model_only.py @@ -3,7 +3,7 @@ import yaml -from gan import MultiScaleDiscriminator +from modules.gan import MultiScaleDiscriminator from modules.inpainting_network import InpaintingNetwork from modules.keypoint_detector import KPDetector from modules.bg_motion_predictor import BGMotionPredictor From 122f95eeb0c4dbf1e377d4e5b725826752ccc314 Mon Sep 17 00:00:00 2001 From: Philipp Haslbauer Date: Wed, 13 Sep 2023 12:37:00 +0200 Subject: [PATCH 26/30] fix yaml loading --- save_model_only.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/save_model_only.py b/save_model_only.py index cbec5f8..1b5f0d8 100644 --- a/save_model_only.py +++ b/save_model_only.py @@ -22,7 +22,7 @@ opt = parser.parse_args() with open(opt.config) as f: - config = yaml.load(f) + config = yaml.load(f, Loader=yaml.FullLoader) checkpoint = torch.load(opt.checkpoint) print(checkpoint.keys()) From 23329fcd7c890af48ba91a5842d1f5a96089ca2b Mon Sep 17 00:00:00 2001 From: Philipp Haslbauer Date: Wed, 13 Sep 2023 12:39:53 +0200 Subject: [PATCH 27/30] reduce batch size --- config/vox-512-deeper-other.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/config/vox-512-deeper-other.yaml b/config/vox-512-deeper-other.yaml index 0857a2c..794767b 100644 --- a/config/vox-512-deeper-other.yaml +++ b/config/vox-512-deeper-other.yaml @@ -1,4 +1,4 @@ -name: vox-256-deeper-other +name: vox-512-deeper-other dataset_params: root_dir: ../data/vox512_filtered_webp @@ -45,7 +45,7 @@ train_params: num_repeats: 5 lr_generator: 2.0e-5 lr_discriminator: 2.0e-5 - batch_size: 4 + batch_size: 2 scales: [1, 0.5, 0.25, 0.125] dataloader_workers: 8 checkpoint_freq: 10 From 92da8baf80900beb8a0f9b1323cc9ded10255e0f Mon Sep 17 00:00:00 2001 From: Philipp Haslbauer Date: Wed, 13 Sep 2023 20:26:08 +0200 Subject: [PATCH 28/30] fix logger --- logger.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/logger.py b/logger.py index 7b498f3..922178e 100644 --- a/logger.py +++ b/logger.py @@ -56,7 +56,7 @@ def visualize_rec(self, inp, out): def save_cpk(self, emergent=False): - cpk = {k: v.state_dict() for k, v in self.models.items()} + cpk = {k: v.state_dict() for k, v in self.models.items() if v is not None} cpk['epoch'] = self.epoch cpk_path = os.path.join(self.cpk_dir, '%s-checkpoint.pth.tar' % str(self.epoch).zfill(self.zfill_num)) if not (os.path.exists(cpk_path) and emergent): From 989c9422a5b4e40922423dfe270fe15cf5f56501 Mon Sep 17 00:00:00 2001 From: Philipp Haslbauer Date: Thu, 21 Sep 2023 11:31:39 +0200 Subject: [PATCH 29/30] 1024 finetune config --- config/vox-1024-finetune.yaml | 29 ++++++++++++++++++++--------- 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/config/vox-1024-finetune.yaml b/config/vox-1024-finetune.yaml index a2a955d..2f2c3ce 100644 --- a/config/vox-1024-finetune.yaml +++ b/config/vox-1024-finetune.yaml @@ -1,4 +1,5 @@ # Use this file to finetune from a pretrained 256x256 model +name: vox-1024-finetune dataset_params: root_dir: ./video-preprocessing/vox2-768 
frame_shape: 1024,1024,3 @@ -35,18 +36,20 @@ model_params: train_params: - num_epochs: 5 - num_repeats: 4 + visualize_model: False + num_epochs: 50 + num_repeats: 1 # Higher LR seems to bring problems when finetuning - lr_generator: 2.0e-5 + lr_generator: 2.0e-6 + lr_discriminator: 2.0e-5 batch_size: 1 - scales: [1, 0.5, 0.25, 0.125, 0.0625, 0.03125] - dataloader_workers: 6 - checkpoint_freq: 5 - dropout_epoch: 2 + scales: [1, 0.5, 0.25, 0.125, 0.0625] + dataloader_workers: 8 + checkpoint_freq: 1 + dropout_epoch: 0 dropout_maxp: 0.3 dropout_startp: 0.1 - dropout_inc_epoch: 1 + dropout_inc_epoch: 0 bg_start: 81 freeze_kp_detector: True freeze_bg_predictor: True @@ -59,11 +62,19 @@ train_params: equivariance_value: 10 warp_loss: 10 bg: 10 + id: 0.1 + l2: 0 + huber: 0 + generator_gan: 0 + generator_feat_match: 0 + discriminator_gan: 0 optimizer: 'adamw' optimizer_params: betas: [ 0.9, 0.999 ] weight_decay: 0.1 - + scheduler: 'onecycle' + scheduler_params: + pct_start: 0.01 train_avd_params: num_epochs: 200 From 3179ac90d492a0e5f5d25a8bc02dd560baa40b90 Mon Sep 17 00:00:00 2001 From: TGG Date: Thu, 21 Sep 2023 11:50:54 +0200 Subject: [PATCH 30/30] add 1536 config file --- config/vox-1536-finetune.yaml | 93 +++++++++++++++++++++++++++++++++++ 1 file changed, 93 insertions(+) create mode 100644 config/vox-1536-finetune.yaml diff --git a/config/vox-1536-finetune.yaml b/config/vox-1536-finetune.yaml new file mode 100644 index 0000000..9a06d61 --- /dev/null +++ b/config/vox-1536-finetune.yaml @@ -0,0 +1,93 @@ +# Use this file to finetune from a pretrained 256x256 model +name: vox-1536-finetune +dataset_params: + root_dir: ./video-preprocessing/vox2-768 + frame_shape: 1536,1536,3 + id_sampling: True + augmentation_params: + flip_param: + horizontal_flip: True + time_flip: True + jitter_param: + brightness: 0.1 + contrast: 0.1 + saturation: 0.1 + hue: 0.1 + + +model_params: + common_params: + num_tps: 10 + num_channels: 3 + bg: True + multi_mask: True + generator_params: + block_expansion: 64 + max_features: 512 + num_down_blocks: 3 + dense_motion_params: + block_expansion: 64 + max_features: 1024 + num_blocks: 5 + scale_factor: 0.25 + avd_network_params: + id_bottle_size: 128 + pose_bottle_size: 128 + + +train_params: + visualize_model: False + num_epochs: 50 + num_repeats: 1 + # Higher LR seems to bring problems when finetuning + lr_generator: 2.0e-6 + lr_discriminator: 2.0e-5 + batch_size: 1 + scales: [1, 0.5, 0.25, 0.125, 0.0625] + dataloader_workers: 8 + checkpoint_freq: 1 + dropout_epoch: 0 + dropout_maxp: 0.3 + dropout_startp: 0.1 + dropout_inc_epoch: 0 + bg_start: 81 + freeze_kp_detector: True + freeze_bg_predictor: True + transform_params: + sigma_affine: 0.05 + sigma_tps: 0.005 + points_tps: 5 + loss_weights: + perceptual: [10, 10, 10, 10, 10] + equivariance_value: 10 + warp_loss: 10 + bg: 10 + id: 0.1 + l2: 0 + huber: 0 + generator_gan: 0 + generator_feat_match: 0 + discriminator_gan: 0 + optimizer: 'adamw' + optimizer_params: + betas: [ 0.9, 0.999 ] + weight_decay: 0.1 + scheduler: 'onecycle' + scheduler_params: + pct_start: 0.01 + +train_avd_params: + num_epochs: 200 + num_repeats: 1 + batch_size: 1 + dataloader_workers: 6 + checkpoint_freq: 1 + epoch_milestones: [140, 180] + lr: 1.0e-3 + lambda_shift: 1 + random_scale: 0.25 + +visualizer_params: + kp_size: 5 + draw_border: True + colormap: 'gist_rainbow'
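
For reference, a minimal sketch of the fixed-weight "sharp" skip block that the inpainting-network patch above introduces (selected via skip_block_type: 'sharp' in the generator config): a depthwise 3x3 convolution whose weights are set to the sharpening kernel returned by get_kernel() and then frozen, mirroring build_sharp_blocks(). This is an illustrative standalone sketch, not the repository's module; the channel count and feature-map size below are made up for the example.

import torch
import torch.nn as nn


def make_sharp_skip_block(channels: int) -> nn.Conv2d:
    # 3x3 kernel corresponding to k2 in get_kernel(), one copy per channel (depthwise).
    k = torch.tensor([[-1., -1., -1.],
                      [-1.,  8., -1.],
                      [-1., -1., -1.]])
    conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1,
                     groups=channels, bias=False)
    # Reshape to (out_channels, 1, 3, 3) and freeze, as in build_sharp_blocks():
    # the skip connection filters encoder features but is not trained.
    weight = k.repeat(channels, 1, 1).unsqueeze(1)
    conv.weight = nn.Parameter(weight, requires_grad=False)
    return conv


if __name__ == "__main__":
    feat = torch.randn(1, 128, 64, 64)   # hypothetical encoder feature map
    skip = make_sharp_skip_block(128)
    print(skip(feat).shape)              # torch.Size([1, 128, 64, 64])

In the patch, the 'depthwise' variant builds the same grouped convolution but leaves its weights trainable, while the 'sharp' variant keeps them fixed to the sharpening kernel; both preserve the feature-map shape so the skip connection can be concatenated or added in the decoder unchanged.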