|
| 1 | +import time |
1 | 2 | import torch |
2 | 3 |
|
3 | | -from network import AvatarNet |
4 | | -from utils import ImageFolder, get_transformer, imsave |
5 | | -from loss import LossCalculator |
| 4 | +from network import AvatarNet, Encoder |
| 5 | +from utils import ImageFolder, imsave, lastest_arverage_value |
6 | 6 |
|
def network_train(args):
    """Train the AvatarNet decoder to reconstruct its own input images.

    Each iteration draws one shuffled batch, runs it through the network and
    minimises the sum of three terms: pixel-level MSE between output and
    input, MSE between the features a separate ``Encoder`` extracts from
    output and input (scaled by ``args.feature_weight``), and the total
    variation of the output (scaled by ``args.tv_weight``).

    Args:
        args: Namespace-like object. The attributes read here are ``gpu_no``,
            ``layers``, ``content_dir``, ``imsize``, ``cropsize``,
            ``cencrop``, ``lr``, ``max_iter``, ``batch_size``,
            ``feature_weight``, ``tv_weight``, ``check_iter`` and
            ``save_path``.

    Returns:
        The trained ``AvatarNet`` instance.

    Side effects:
        Every ``args.check_iter`` iterations a preview image and a checkpoint
        are written under ``args.save_path`` and a loss line is printed.
    """
    # set device
    device = torch.device('cuda' if args.gpu_no >= 0 else 'cpu')

    # get network
    network = AvatarNet(args.layers).to(device)

    # get data set
    data_set = ImageFolder(args.content_dir, args.imsize, args.cropsize, args.cencrop)
    # The loader itself is loop-invariant, so build it once. A fresh iterator
    # is still created per step below, which keeps the original sampling
    # behaviour (first batch of a new shuffle every iteration).
    data_loader = torch.utils.data.DataLoader(data_set, batch_size=args.batch_size, shuffle=True)

    # get loss calculator
    loss_network = Encoder(args.layers).to(device)
    # The loss network is never optimised; freezing it stops backward() from
    # computing and accumulating gradients on its weights that nothing uses.
    # Gradients still flow through it to the decoder output.
    for param in loss_network.parameters():
        param.requires_grad = False
    mse_loss = torch.nn.MSELoss(reduction='mean').to(device)
    loss_seq = {'total': [], 'image': [], 'feature': [], 'tv': []}

    # get optimizer (only the decoder is trained; the encoder stays fixed)
    for param in network.encoder.parameters():
        param.requires_grad = False
    optimizer = torch.optim.Adam(network.decoder.parameters(), lr=args.lr)

    # training
    for iteration in range(args.max_iter):
        input_image = next(iter(data_loader)).to(device)

        # The second argument is a one-element list holding the same batch
        # (presumably the style inputs -- confirm against AvatarNet.forward).
        output_image = network(input_image, [input_image], train=True)

        # calculate losses
        ## image reconstruction loss
        image_loss = mse_loss(output_image, input_image)
        loss_seq['image'].append(image_loss.item())

        ## feature reconstruction loss
        # Target features do not depend on any trainable weight, so no graph
        # is needed for them.
        with torch.no_grad():
            input_features = loss_network(input_image)
        output_features = loss_network(output_image)
        feature_loss = sum(
            mse_loss(output_feature, input_feature)
            for output_feature, input_feature in zip(output_features, input_features)
        )
        # float() also accepts the plain 0 produced when no features are
        # returned, where .item() would raise AttributeError.
        loss_seq['feature'].append(float(feature_loss))

        ## total variation loss
        tv_loss = calc_tv_loss(output_image)
        loss_seq['tv'].append(tv_loss.item())

        total_loss = image_loss + feature_loss * args.feature_weight + tv_loss * args.tv_weight
        loss_seq['total'].append(total_loss.item())

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        # print loss log and save network, loss log and output images
        if (iteration + 1) % args.check_iter == 0:
            imsave(torch.cat([input_image, output_image], dim=0), args.save_path+"training_image.png")
            print("%s: Iteration: [%d/%d]\tImage Loss: %2.4f\tFeature Loss: %2.4f\tTV Loss: %2.4f\tTotal: %2.4f"%(time.ctime(),iteration+1,
                args.max_iter, lastest_arverage_value(loss_seq['image']), lastest_arverage_value(loss_seq['feature']),
                lastest_arverage_value(loss_seq['tv']), lastest_arverage_value(loss_seq['total'])))
            torch.save({'iteration': iteration+1,
                'state_dict': network.state_dict(),
                'loss_seq': loss_seq},
                args.save_path+'check_point.pth')

    return network
| 73 | + |
def calc_tv_loss(x):
    """Return the anisotropic total-variation loss of a 4-D tensor.

    The loss is the mean absolute difference between neighbouring elements
    along the last axis plus the mean absolute difference between
    neighbouring elements along the second-to-last axis.

    Args:
        x: Tensor indexed with four axes (the call site passes the network's
            output image batch; presumably laid out as N, C, H, W).

    Returns:
        A zero-dimensional tensor with the same dtype and device as ``x``.
        An axis of length 1 has no neighbouring pairs and contributes 0.
    """
    neighbour_diffs = (
        x[:, :, :, :-1] - x[:, :, :, 1:],  # along the last axis
        x[:, :, :-1, :] - x[:, :, 1:, :],  # along the second-to-last axis
    )
    tv_loss = x.new_zeros(())
    for diff in neighbour_diffs:
        # torch.mean of an empty tensor is NaN, which would poison the total
        # loss; a degenerate axis simply adds nothing.
        if diff.numel() > 0:
            tv_loss = tv_loss + torch.mean(torch.abs(diff))
    return tv_loss
| 78 | + |
0 commit comments