"""
Modules for the SPAGHETTI model to translate microscopy images to H&E images
"""
import torch.nn as nn
import torch.nn.functional as F
import torch


class ResidualBlock(nn.Module):
    def __init__(self, in_channels):
        super(ResidualBlock, self).__init__()
        self.block = nn.Sequential(
            nn.ReflectionPad2d(1),  # padding, keeps the image size constant after the next conv2d
            nn.Conv2d(in_channels, in_channels, 3),
            nn.InstanceNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_channels, in_channels, 3),
            nn.InstanceNorm2d(in_channels)
        )

    def forward(self, x):
        return x + self.block(x)

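# Shape note (illustrative): ReflectionPad2d(1) followed by a 3x3 conv leaves H and W
# unchanged, so the block is shape-preserving, e.g. ResidualBlock(64) maps a tensor of
# shape (N, 64, H, W) to (N, 64, H, W).
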
class GeneratorResNet(nn.Module):
    def __init__(self, in_channels, num_residual_blocks=9):
        super(GeneratorResNet, self).__init__()

        # Initial convolution 3*256*256 -> 64*256*256
        out_channels = 64
        self.conv = nn.Sequential(
            nn.ReflectionPad2d(in_channels),  # padding, keeps the image size constant after the next conv2d
            nn.Conv2d(in_channels, out_channels, 2 * in_channels + 1),
            nn.InstanceNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

        channels = out_channels

        # Downsampling 64*256*256 -> 128*128*128 -> 256*64*64
        self.down = []
        for _ in range(2):
            # Assumed downsampling layers (mirroring the upsampling block below):
            # a stride-2 conv that doubles the channels and halves the spatial size.
            out_channels = channels * 2
            self.down += [
                nn.Conv2d(channels, out_channels, 3, stride=2, padding=1),
                nn.InstanceNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
            channels = out_channels
        self.down = nn.Sequential(*self.down)

        # Transformation (ResNet) 256*64*64
        self.trans = [ResidualBlock(channels) for _ in range(num_residual_blocks)]
        self.trans = nn.Sequential(*self.trans)

        # Upsampling 256*64*64 -> 128*128*128 -> 64*256*256
        self.up = []
        for _ in range(2):
            out_channels = channels // 2
            self.up += [
                nn.Upsample(scale_factor=2),  # nearest-neighbour interpolation (nn.Upsample default mode)
                nn.Conv2d(channels, out_channels, 3, stride=1, padding=1),
                nn.InstanceNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
            channels = out_channels
        self.up = nn.Sequential(*self.up)

        # Output layer 64*256*256 -> 3*256*256
        self.out = nn.Sequential(
            nn.ReflectionPad2d(in_channels),
            nn.Conv2d(channels, in_channels, 2 * in_channels + 1),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.conv(x)
        x = self.down(x)
        x = self.trans(x)  # residual transformation stage
        x = self.up(x)  # upsampling stage
        x = self.out(x)
        return x

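# Illustrative usage sketch (assuming 3-channel 256*256 inputs, as in the shape comments
# above): the generator maps an image to an image of the same size with values in [-1, 1]
# from the final Tanh, e.g.
#   g = GeneratorResNet(in_channels=3)
#   fake_he = g(torch.randn(1, 3, 256, 256))   # -> shape (1, 3, 256, 256)
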
class Discriminator(nn.Module):
    def __init__(self, in_channels):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            *self.block(in_channels, 64, normalize=False),  # 3*256*256 -> 64*128*128
            *self.block(64, 128),  # 64*128*128 -> 128*64*64
            *self.block(128, 256),  # 128*64*64 -> 256*32*32
            *self.block(256, 512),  # 256*32*32 -> 512*16*16

            nn.ZeroPad2d((1, 0, 1, 0)),  # padding left and top 512*16*16 -> 512*17*17
            nn.Conv2d(512, 1, 4, padding=1)  # 512*17*17 -> 1*16*16
        )

        self.scale_factor = 16

    @staticmethod
    def block(in_channels, out_channels, normalize=True):
        layers = [nn.Conv2d(in_channels, out_channels, 4, stride=2, padding=1)]
        if normalize:
            layers.append(nn.InstanceNorm2d(out_channels))
        layers.append(nn.LeakyReLU(0.2, inplace=True))

        return layers

    def forward(self, x):
        return self.model(x)

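# Illustrative usage sketch (assuming 3-channel 256*256 inputs): the discriminator outputs
# a 1*16*16 map of per-patch scores rather than a single scalar; scale_factor = 16 records
# that each side of the input is 16x the side of this score map, e.g.
#   d = Discriminator(in_channels=3)
#   patch_scores = d(torch.randn(1, 3, 256, 256))   # -> shape (1, 1, 16, 16)
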
class SSIMLoss(nn.Module):
    def __init__(self, window_size=11, size_average=True):
        super(SSIMLoss, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        self.create_window(window_size)

    def create_window(self, window_size, channel=1):
        # Create a Gaussian window (filter) with specified size and channel
        sigma = 1.5
        # Assumed construction of the standard separable Gaussian window that
        # forward() expects in self.window, shaped (channel, 1, window_size, window_size).
        coords = torch.arange(window_size, dtype=torch.float32) - window_size // 2
        gauss = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
        gauss = gauss / gauss.sum()
        _1d = gauss.unsqueeze(1)
        _2d = _1d.mm(_1d.t()).unsqueeze(0).unsqueeze(0)
        self.window = _2d.expand(channel, 1, window_size, window_size).contiguous()

    def forward(self, img1, img2):
        if img1.size(1) != self.channel:
            self.channel = img1.size(1)
            self.create_window(self.window_size, self.channel)

        # Move window to the same device as the images
        window = self.window.to(img1.device)

        # Compute SSIM components
        mu1 = F.conv2d(img1, window, padding=self.window_size // 2, groups=self.channel)
        mu2 = F.conv2d(img2, window, padding=self.window_size // 2, groups=self.channel)

        mu1_sq = mu1.pow(2)
        mu2_sq = mu2.pow(2)
        mu1_mu2 = mu1 * mu2
        # Assumed standard SSIM (co)variance terms and stabilising constants,
        # as used in the formula below.
        sigma1_sq = F.conv2d(img1 * img1, window, padding=self.window_size // 2, groups=self.channel) - mu1_sq
        sigma2_sq = F.conv2d(img2 * img2, window, padding=self.window_size // 2, groups=self.channel) - mu2_sq
        sigma12 = F.conv2d(img1 * img2, window, padding=self.window_size // 2, groups=self.channel) - mu1_mu2

        C1 = 0.01 ** 2
        C2 = 0.03 ** 2

        # SSIM calculation
        ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

        if self.size_average:
            return torch.clamp((1 - ssim_map.mean()) / 2, 0, 1)  # SSIM loss as (1 - SSIM) / 2
        else: