@@ -48,14 +48,25 @@ def calculate_loss_generator(self, res, x1, x2):
4848 x1 .size (3 )// self .D_A .scale_factor ]
4949 valid = torch .ones (out_shape ).to (self .device )
5050 fake_x1 , fake_x2 , recov_x1 , recov_x2 , new_x1 , new_x2 = res
51+ # convert to gray scale for ssim calculation
52+ gray = Grayscale (num_output_channels = 3 )
53+ x1_gray = gray (x1 )
54+ x2_gray = gray (x2 )
55+ fake_x1_gray = gray (fake_x1 )
56+ fake_x2_gray = gray (fake_x2 )
57+ new_x1_gray = gray (new_x1 )
58+ new_x2_gray = gray (new_x2 )
5159 loss_GAN = (self .criterion_GAN (self .D_A (fake_x1 ), valid )
5260 + self .criterion_GAN (self .D_B (fake_x2 ), valid ))/ 2
5361 loss_cycle = (self .criterion_cycle (recov_x1 , x1 )
5462 + self .criterion_cycle (recov_x2 , x2 ))/ 2
5563 loss_identity = (self .criterion_identity (new_x1 , x1 )
5664 + self .criterion_identity (new_x2 , x2 ))/ 2
57- loss_ssim = (self .criterion_ssim (new_x1 , x1 )
58- + self .criterion_ssim (new_x2 , x2 )) / 2
65+ loss_ssim_fake = (self .criterion_ssim (fake_x2_gray , x1_gray )
66+ + self .criterion_ssim (fake_x1_gray , x2_gray )) / 2
67+ loss_ssim_real = (self .criterion_ssim (new_x1_gray , x1_gray )
68+ + self .criterion_ssim (new_x2_gray , x2_gray )) / 2
69+ loss_ssim = (loss_ssim_fake + loss_ssim_real ) / 2
5970 total_loss = (self .weights [0 ] * loss_GAN + self .weights [1 ] * loss_cycle
6071 + self .weights [2 ] * loss_identity + self .weights [3 ] * loss_ssim )
6172 return total_loss
0 commit comments