Skip to content

Commit e51f155

Browse files
committed
update for calculating loss
1 parent 6d0306f commit e51f155

File tree

1 file changed

+50
-19
lines changed

1 file changed

+50
-19
lines changed

train.py

Lines changed: 50 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,47 +1,78 @@
1+
import time
12
import torch
23

3-
from network import AvatarNet
4-
from utils import ImageFolder, get_transformer, imsave
5-
from loss import LossCalculator
4+
from network import AvatarNet, Encoder
5+
from utils import ImageFolder, imsave, lastest_arverage_value
66

77
def network_train(args):
88
# set device
9-
device = torch.device("cuda" if args.cuda_device_no >= 0 else "cpu")
10-
11-
# save arguments
12-
torch.save(args, args.save_path+"arguments.pth")
9+
device = torch.device('cuda' if args.gpu_no >= 0 else 'cpu')
1310

1411
# get network
15-
network = AvatarNet(args.layers)
16-
network = network.to(device)
12+
network = AvatarNet(args.layers).to(device)
1713

1814
# get data set
19-
data_set = ImageFolder(args.train_data_path, get_transformer(args.imsize, args.cropsize))
15+
data_set = ImageFolder(args.content_dir, args.imsize, args.cropsize, args.cencrop)
2016

2117
# get loss calculator
22-
loss_calculator = LossCalculator(device, args.layers, args.feature_weight, args.reconstruction_weight, args.tv_weight)
18+
loss_network = Encoder(args.layers).to(device)
19+
mse_loss = torch.nn.MSELoss(reduction='mean').to(device)
20+
loss_seq = {'total':[], 'image':[], 'feature':[], 'tv':[]}
2321

2422
# get optimizer
25-
optimizer = torch.optim.Adam(network.decoders.parameters(), lr=args.lr)
23+
for param in network.encoder.parameters():
24+
param.requires_grad = False
25+
optimizer = torch.optim.Adam(network.decoder.parameters(), lr=args.lr)
2626

2727
# training
2828
for iteration in range(args.max_iter):
2929
data_loader = torch.utils.data.DataLoader(data_set, batch_size=args.batch_size, shuffle=True)
30-
image = next(iter(data_loader)).to(device)
30+
input_image = next(iter(data_loader)).to(device)
31+
32+
output_image = network(input_image, [input_image], train=True)
33+
34+
# calculate losses
35+
total_loss = 0
36+
## image reconstruction loss
37+
image_loss = mse_loss(output_image, input_image)
38+
loss_seq['image'].append(image_loss.item())
39+
total_loss += image_loss
3140

32-
output = network(image, image, train_flag=True)
41+
## feature reconstruction loss
42+
input_features = loss_network(input_image)
43+
output_features = loss_network(output_image)
44+
feature_loss = 0
45+
for output_feature, input_feature in zip(output_features, input_features):
46+
feature_loss += mse_loss(output_feature, input_feature)
47+
loss_seq['feature'].append(feature_loss.item())
48+
total_loss += feature_loss * args.feature_weight
3349

34-
total_loss = loss_calculator.calc_total_loss(output, image)
50+
## total variation loss
51+
tv_loss = calc_tv_loss(output_image)
52+
loss_seq['tv'].append(tv_loss.item())
53+
total_loss += tv_loss * args.tv_weight
54+
55+
loss_seq['total'].append(total_loss.item())
3556

3657
optimizer.zero_grad()
3758
total_loss.backward()
3859
optimizer.step()
3960

4061
# print loss log and save network, loss log and output images
4162
if (iteration + 1) % args.check_iter == 0:
42-
loss_calculator.print_loss_seq()
43-
torch.save(network.state_dict(), args.save_path+"network.pth")
44-
torch.save(loss_calculator.loss_seq, args.save_path+"loss_seq.pth")
45-
imsave(output, args.save_path+"training_image.png")
63+
imsave(torch.cat([input_image, output_image], dim=0), args.save_path+"training_image.png")
64+
print("%s: Iteration: [%d/%d]\tImage Loss: %2.4f\tFeature Loss: %2.4f\tTV Loss: %2.4f\tTotal: %2.4f"%(time.ctime(),iteration+1,
65+
args.max_iter, lastest_arverage_value(loss_seq['image']), lastest_arverage_value(loss_seq['feature']),
66+
lastest_arverage_value(loss_seq['tv']), lastest_arverage_value(loss_seq['total'])))
67+
torch.save({'iteration': iteration+1,
68+
'state_dict': network.state_dict(),
69+
'loss_seq': loss_seq},
70+
args.save_path+'check_point.pth')
4671

4772
return network
73+
74+
def calc_tv_loss(x):
75+
tv_loss = torch.mean(torch.abs(x[:, :, :, :-1] - x[:, :, :, 1:]))
76+
tv_loss += torch.mean(torch.abs(x[:, :, :-1, :] - x[:, :, 1:, :]))
77+
return tv_loss
78+

0 commit comments

Comments
 (0)