Skip to content

Commit 3cf6fbf

Browse files
committed
resolve crop-related issues
1 parent c6d05f6 commit 3cf6fbf

File tree

7 files changed

+30
-12
lines changed

7 files changed

+30
-12
lines changed

deepem/models/layers.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,5 +39,6 @@ def __init__(self, cropsz):
3939
def forward(self, x):
4040
if self.cropsz is not None:
4141
for k, v in x.items():
42-
x[k] = torch_utils.crop_center(v, self.cropsz)
42+
cropsz = [int(v.shape[i]*self.cropsz[i]) for i in [-3,-2,-1]]
43+
x[k] = torch_utils.crop_center(v, cropsz)
4344
return x

deepem/test/option.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,8 @@ def parse(self):
124124
diff = np.array(opt.fov) - np.array(opt.outputsz)
125125
assert all(diff >= 0)
126126
if any(diff > 0):
127-
opt.cropsz = opt.outputsz
127+
# opt.cropsz = opt.outputsz
128+
opt.cropsz = [o/float(f) for f,o in zip(opt.fov,opt.outputsz)]
128129
else:
129130
opt.cropsz = None
130131

deepem/train/data.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@ def __init__(self, opt, data, is_train=True, prob=None):
3232

3333
def __call__(self):
3434
sample = next(self.dataiter)
35-
sample = self.modifier(sample)
35+
if self.is_train:
36+
sample = self.modifier(sample)
3637
for k in sample:
3738
is_input = k in self.inputs
3839
sample[k].requires_grad_(is_input)

deepem/train/logger.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import sys
33
import datetime
44
from collections import OrderedDict
5+
import numpy as np
56

67
import torch
78
from torchvision.utils import make_grid
@@ -87,15 +88,22 @@ def flush(self):
8788
return ret
8889

8990
def log_images(self, phase, iter_num, preds, sample):
91+
# Peep output size
92+
key = sorted(self.out_spec)[0]
93+
cropsz = sample[key].shape[-3:]
94+
for k in sorted(self.out_spec):
95+
outsz = sample[k].shape[-3:]
96+
assert np.array_equal(outsz, cropsz)
97+
9098
# Inputs
9199
for k in sorted(self.in_spec):
92100
tag = '{}/images/{}'.format(phase, k)
93101
tensor = sample[k][0,...].cpu()
94-
tensor = torch_utils.crop_center(tensor, self.outputsz)
102+
tensor = torch_utils.crop_center(tensor, cropsz)
95103
self.log_image(tag, tensor, iter_num)
96104

97105
# Outputs
98-
for k in sorted(self.out_spec):
106+
for k in sorted(self.out_spec):
99107
if k == 'affinity':
100108
# Prediction
101109
tag = '{}/images/{}'.format(phase, k)

deepem/train/option.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,9 @@ def initialize(self):
107107
self.parser.add_argument('--glia', type=float, default=0) # Glia
108108
self.parser.add_argument('--glia_mask', action='store_true')
109109

110+
# Test training
111+
self.parser.add_argument('--test', action='store_true')
112+
110113
self.initialized = True
111114

112115
def parse(self):
@@ -119,6 +122,8 @@ def parse(self):
119122
opt.exp_dir = opt.exp_name
120123
else:
121124
opt.exp_dir = 'experiments/{}'.format(opt.exp_name)
125+
if opt.test:
126+
opt.exp_dir = 'test/' + opt.exp_dir
122127
opt.log_dir = os.path.join(opt.exp_dir, 'logs')
123128
opt.model_dir = os.path.join(opt.exp_dir, 'models')
124129

@@ -162,7 +167,8 @@ def parse(self):
162167
diff = np.array(opt.fov) - np.array(opt.outputsz)
163168
assert all(diff >= 0)
164169
if any(diff > 0):
165-
opt.cropsz = opt.outputsz
170+
# opt.cropsz = opt.outputsz
171+
opt.cropsz = [o/float(f) for f,o in zip(opt.fov,opt.outputsz)]
166172
else:
167173
opt.cropsz = None
168174

@@ -181,6 +187,13 @@ def parse(self):
181187
'glia': ('glia', 1),
182188
}
183189

190+
# Test training
191+
if opt.test:
192+
opt.eval_intv = 100
193+
opt.eval_iter = 10
194+
opt.avgs_intv = 10
195+
opt.imgs_intv = 100
196+
184197
for k, v in class_dict.items():
185198
loss_w = args[k]
186199
if loss_w > 0:

deepem/train/run.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,6 @@ def train(opt):
3535
# Load training samples.
3636
sample = train_loader()
3737

38-
# Crop sample.
39-
40-
4138
# Optimizer step
4239
optimizer.zero_grad()
4340
losses, nmasks, preds = forward(model, sample, opt)

deepem/utils/torch_utils.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,6 @@ def crop_border(v, size):
3939

4040

4141
def crop_center(v, size):
42-
# TODO: hack
43-
if all([a <= b for a, b in zip(v.shape[-3:], size[-3:])]):
44-
return v
4542
assert all([a >= b for a, b in zip(v.shape[-3:], size[-3:])])
4643
z, y, x = size[-3:]
4744
sx = (v.shape[-1] - x) // 2

0 commit comments

Comments (0)