Skip to content

Commit 2128c41

Browse files
committed
update spaghetti training loss calculation
1 parent 6483014 commit 2128c41

File tree

2 files changed

+14
-3
lines changed

2 files changed

+14
-3
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
setup(
44
name="pcm-spaghetti",
5-
version="1.0.2",
5+
version="1.1",
66
author="Richard (Zhi Fei) Dong, Chris McIntosh, Gregory W. Schwartz",
77
author_email="gregory.schwartz@uhn.ca",
88
description="A PyTorch implementation of the SPAGHETTI model for phase-contrast microscopy image transformation",

spaghetti/train.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,14 +48,25 @@ def calculate_loss_generator(self, res, x1, x2):
4848
x1.size(3)//self.D_A.scale_factor]
4949
valid = torch.ones(out_shape).to(self.device)
5050
fake_x1, fake_x2, recov_x1, recov_x2, new_x1, new_x2 = res
51+
# convert to gray scale for ssim calculation
52+
gray = Grayscale(num_output_channels=3)
53+
x1_gray = gray(x1)
54+
x2_gray = gray(x2)
55+
fake_x1_gray = gray(fake_x1)
56+
fake_x2_gray = gray(fake_x2)
57+
new_x1_gray = gray(new_x1)
58+
new_x2_gray = gray(new_x2)
5159
loss_GAN = (self.criterion_GAN(self.D_A(fake_x1), valid)
5260
+ self.criterion_GAN(self.D_B(fake_x2), valid))/2
5361
loss_cycle = (self.criterion_cycle(recov_x1, x1)
5462
+ self.criterion_cycle(recov_x2, x2))/2
5563
loss_identity = (self.criterion_identity(new_x1, x1)
5664
+ self.criterion_identity(new_x2, x2))/2
57-
loss_ssim = (self.criterion_ssim(new_x1, x1)
58-
+ self.criterion_ssim(new_x2, x2)) / 2
65+
loss_ssim_fake = (self.criterion_ssim(fake_x2_gray, x1_gray)
66+
+ self.criterion_ssim(fake_x1_gray, x2_gray)) / 2
67+
loss_ssim_real = (self.criterion_ssim(new_x1_gray, x1_gray)
68+
+ self.criterion_ssim(new_x2_gray, x2_gray)) / 2
69+
loss_ssim = (loss_ssim_fake + loss_ssim_real) / 2
5970
total_loss = (self.weights[0] * loss_GAN + self.weights[1] * loss_cycle
6071
+ self.weights[2] * loss_identity + self.weights[3] * loss_ssim)
6172
return total_loss

0 commit comments

Comments
 (0)