Skip to content

Commit 932fcb0

Browse files
authored
Merge pull request #3 from schwartzlab-methods/copilot/code-review-for-spaghetti-api
Fix critical bug and improve code quality for SPAGHETTI API
2 parents d756cbe + ae296df commit 932fcb0

File tree

7 files changed

+154
-116
lines changed

7 files changed

+154
-116
lines changed

spaghetti/__init__.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
"""
2+
SPAGHETTI - Structural Phase Adaptation via Generative Histological Enhancement
3+
and Texture-preserving Translation Integration
4+
5+
A PyTorch implementation for phase-contrast microscopy image transformation.
6+
"""
7+
8+
from spaghetti.inferences import Spaghetti
9+
from spaghetti.dataset import TrainingDataset
10+
from spaghetti.train import train_spaghetti
11+
from spaghetti import utils
12+
13+
__all__ = [
14+
"Spaghetti",
15+
"TrainingDataset",
16+
"train_spaghetti",
17+
"utils",
18+
]

spaghetti/_spaghetti_modules.py

Lines changed: 35 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,43 @@
1-
'''
1+
"""
22
Modules for the SPAGHETTI model to translate microscopy images to H&E images
3-
'''
3+
"""
44
import torch.nn as nn
55
import torch.nn.functional as F
66
import torch
77

8+
89
class ResidualBlock(nn.Module):
910
def __init__(self, in_channels):
1011
super(ResidualBlock, self).__init__()
1112
self.block = nn.Sequential(
12-
nn.ReflectionPad2d(1), # padding, keep the image size constant after next conv2d
13+
nn.ReflectionPad2d(1), # padding, keep the image size constant after next conv2d
1314
nn.Conv2d(in_channels, in_channels, 3),
1415
nn.InstanceNorm2d(in_channels),
1516
nn.ReLU(inplace=True),
1617
nn.ReflectionPad2d(1),
1718
nn.Conv2d(in_channels, in_channels, 3),
1819
nn.InstanceNorm2d(in_channels)
1920
)
20-
21+
2122
def forward(self, x):
2223
return x + self.block(x)
23-
24+
25+
2426
class GeneratorResNet(nn.Module):
2527
def __init__(self, in_channels, num_residual_blocks=9):
2628
super(GeneratorResNet, self).__init__()
27-
29+
2830
# Inital Convolution 3*256*256 -> 64*256*256
29-
out_channels=64
31+
out_channels = 64
3032
self.conv = nn.Sequential(
31-
nn.ReflectionPad2d(in_channels), # padding, keep the image size constant after next conv2d
33+
nn.ReflectionPad2d(in_channels), # padding, keep the image size constant after next conv2d
3234
nn.Conv2d(in_channels, out_channels, 2*in_channels+1),
3335
nn.InstanceNorm2d(out_channels),
3436
nn.ReLU(inplace=True),
3537
)
36-
38+
3739
channels = out_channels
38-
40+
3941
# Downsampling 64*256*256 -> 128*128*128 -> 256*64*64
4042
self.down = []
4143
for _ in range(2):
@@ -47,31 +49,31 @@ def __init__(self, in_channels, num_residual_blocks=9):
4749
]
4850
channels = out_channels
4951
self.down = nn.Sequential(*self.down)
50-
52+
5153
# Transformation (ResNet) 256*64*64
5254
self.trans = [ResidualBlock(channels) for _ in range(num_residual_blocks)]
5355
self.trans = nn.Sequential(*self.trans)
54-
56+
5557
# Upsampling 256*64*64 -> 128*128*128 -> 64*256*256
5658
self.up = []
5759
for _ in range(2):
5860
out_channels = channels // 2
5961
self.up += [
60-
nn.Upsample(scale_factor=2), # bilinear interpolation
62+
nn.Upsample(scale_factor=2), # bilinear interpolation
6163
nn.Conv2d(channels, out_channels, 3, stride=1, padding=1),
6264
nn.InstanceNorm2d(out_channels),
6365
nn.ReLU(inplace=True),
6466
]
6567
channels = out_channels
6668
self.up = nn.Sequential(*self.up)
67-
69+
6870
# Out layer 64*256*256 -> 3*256*256
6971
self.out = nn.Sequential(
7072
nn.ReflectionPad2d(in_channels),
7173
nn.Conv2d(channels, in_channels, 2*in_channels+1),
7274
nn.Tanh()
7375
)
74-
76+
7577
def forward(self, x):
7678
x = self.conv(x)
7779
x = self.down(x)
@@ -80,42 +82,44 @@ def forward(self, x):
8082
x = self.out(x)
8183
return x
8284

85+
8386
class Discriminator(nn.Module):
8487
def __init__(self, in_channels):
8588
super(Discriminator, self).__init__()
86-
89+
8790
self.model = nn.Sequential(
88-
*self.block(in_channels, 64, normalize=False), # 3*256*256 -> 64*128*128
91+
*self.block(in_channels, 64, normalize=False), # 3*256*256 -> 64*128*128
8992
*self.block(64, 128), # 64*128*128 -> 128*64*64
90-
*self.block(128, 256), # 128*64*64 -> 256*32*32
91-
*self.block(256, 512), # 256*32*32 -> 512*16*16
92-
93-
nn.ZeroPad2d((1,0,1,0)), # padding left and top 512*16*16 -> 512*17*17
94-
nn.Conv2d(512, 1, 4, padding=1) # 512*17*17 -> 1*16*16
93+
*self.block(128, 256), # 128*64*64 -> 256*32*32
94+
*self.block(256, 512), # 256*32*32 -> 512*16*16
95+
96+
nn.ZeroPad2d((1, 0, 1, 0)), # padding left and top 512*16*16 -> 512*17*17
97+
nn.Conv2d(512, 1, 4, padding=1) # 512*17*17 -> 1*16*16
9598
)
96-
99+
97100
self.scale_factor = 16
98-
101+
99102
@staticmethod
100103
def block(in_channels, out_channels, normalize=True):
101104
layers = [nn.Conv2d(in_channels, out_channels, 4, stride=2, padding=1)]
102105
if normalize:
103106
layers.append(nn.InstanceNorm2d(out_channels))
104107
layers.append(nn.LeakyReLU(0.2, inplace=True))
105-
108+
106109
return layers
107-
110+
108111
def forward(self, x):
109112
return self.model(x)
110113

114+
111115
class SSIMLoss(nn.Module):
112116
def __init__(self, window_size=11, size_average=True):
113117
super(SSIMLoss, self).__init__()
114118
self.window_size = window_size
115119
self.size_average = size_average
116120
self.channel = 1
117121
self.create_window(window_size)
118-
122+
119123
def create_window(self, window_size, channel=1):
120124
# Create a Gaussian window (filter) with specified size and channel
121125
sigma = 1.5
@@ -131,14 +135,14 @@ def forward(self, img1, img2):
131135
if img1.size(1) != self.channel:
132136
self.channel = img1.size(1)
133137
self.create_window(self.window_size, self.channel)
134-
138+
135139
# Move window to the same device as the images
136140
window = self.window.to(img1.device)
137-
141+
138142
# Compute SSIM components
139143
mu1 = F.conv2d(img1, window, padding=self.window_size//2, groups=self.channel)
140144
mu2 = F.conv2d(img2, window, padding=self.window_size//2, groups=self.channel)
141-
145+
142146
mu1_sq = mu1.pow(2)
143147
mu2_sq = mu2.pow(2)
144148
mu1_mu2 = mu1 * mu2
@@ -153,7 +157,7 @@ def forward(self, img1, img2):
153157

154158
# SSIM calculation
155159
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
156-
160+
157161
if self.size_average:
158162
return torch.clamp((1 - ssim_map.mean()) / 2, 0, 1) # SSIM loss as (1 - SSIM) / 2
159163
else:

spaghetti/cli_inference.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,20 @@
1-
'''
1+
"""
22
Entry point for the CLI inference of SPAGHETTI.
3-
'''
3+
"""
44
import os
55
import argparse
66
from spaghetti import inferences
77
from PIL import Image
88

9+
910
def inference(input, output, checkpoint):
10-
'''
11+
"""
1112
The inference function for the CLI inference
1213
args:
1314
input: str, the input image or directory to translate
1415
output: str, the output directory to save the translated image(s)
1516
checkpoint: str, the path to the checkpoint file
16-
'''
17+
"""
1718
# check if input is a directory
1819
if os.path.isdir(input):
1920
# get all images
@@ -34,6 +35,7 @@ def inference(input, output, checkpoint):
3435
processed_imgs = model.pre_processing(pil_imgs, transform="default")
3536
model.inference(processed_imgs, names, output)
3637

38+
3739
def main():
3840
parser = argparse.ArgumentParser(description="CLI for translating PCM images using SPAGHETTI")
3941
parser.add_argument("--input", '-i', type=str, help="The input image or directory to translate")
@@ -47,5 +49,6 @@ def main():
4749
inference(args.input, args.output, args.checkpoint)
4850
print("Inference Completed. Images saved to ", args.output)
4951

52+
5053
if __name__ == "__main__":
51-
main()
54+
main()

spaghetti/dataset.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,29 @@
1-
'''
1+
"""
22
Prepare the datasets for training and inference
3-
'''
3+
"""
44
import random
55
from PIL import Image
66
import os
77
from torch.utils.data import Dataset
88

9+
910
class TrainingDataset(Dataset):
10-
'''
11+
"""
1112
The dataset class for the SPAGHETTI model training
1213
args:
1314
path_1: list of strings, the paths to the images in domain 1
1415
path_2: list of strings, the paths to the images in domain 2
1516
transform_1: the transformation for domain 1 images
1617
transform_2: the transformation for domain 2 images
1718
num_sample: int, optional, the number of images to sample from each domain
18-
'''
19-
def __init__(self, path_1: list[str], path_2: list[str],
19+
"""
20+
def __init__(self, path_1: list[str], path_2: list[str],
2021
transform_1, transform_2, num_sample=None):
2122
random.seed(42)
2223
# domain 1
2324
domain1_paths = []
2425
for each in path_1:
25-
domain1_paths.extend([os.path.join(each, x) for x in os.listdir(each)
26+
domain1_paths.extend([os.path.join(each, x) for x in os.listdir(each)
2627
if x.endswith((".png", ".jpg", ".jpeg", ".tiff", ".tif"))])
2728
if num_sample:
2829
try:
@@ -31,7 +32,7 @@ def __init__(self, path_1: list[str], path_2: list[str],
3132
self.domain1_images = random.choices(domain1_paths, k=num_sample)
3233
else:
3334
self.domain1_images = domain1_paths
34-
35+
3536
# domain 2
3637
domain2_paths = []
3738
for each in path_2:
@@ -60,9 +61,9 @@ def __getitem__(self, index):
6061
domain1_img_path = self.domain1_images[index % self.domain1_len]
6162
domain2_img_path = self.domain2_images[index % self.domain2_len]
6263
domain1_img = Image.open(domain1_img_path).convert("RGB")
63-
domain2_img = Image.open(domain2_img_path).convert("RGB")
64+
domain2_img = Image.open(domain2_img_path).convert("RGB")
6465

6566
domain1_img = self.transform_1(domain1_img)
6667
domain2_img = self.transform_2(domain2_img)
6768

68-
return domain1_img, domain2_img
69+
return domain1_img, domain2_img

0 commit comments

Comments
 (0)