Skip to content

Commit 2875302

Browse files
committed
resolve crop-related issues
1 parent cdb70bd commit 2875302

File tree

5 files changed

+27
-12
lines changed

5 files changed

+27
-12
lines changed

deepem/models/rsunet_deprecated.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import emvision
55

6-
from deepem.models.layers import Conv
6+
from deepem.models.layers import Conv, Crop
77

88

99
def create_model(opt):
@@ -19,7 +19,7 @@ def create_model(opt):
1919
else:
2020
# Batch (instance) normalization
2121
core = emvision.models.RSUNet(width=width[:depth])
22-
return Model(core, opt.in_spec, opt.out_spec, width[0])
22+
return Model(core, opt.in_spec, opt.out_spec, width[0], cropsz=opt.cropsz)
2323

2424

2525
class InputBlock(nn.Sequential):
@@ -45,7 +45,7 @@ class Model(nn.Sequential):
4545
"""
4646
Residual Symmetric U-Net.
4747
"""
48-
def __init__(self, core, in_spec, out_spec, out_channels):
48+
def __init__(self, core, in_spec, out_spec, out_channels, cropsz=None):
4949
super(Model, self).__init__()
5050

5151
assert len(in_spec)==1, "model takes a single input"
@@ -55,3 +55,5 @@ def __init__(self, core, in_spec, out_spec, out_channels):
5555
self.add_module('in', InputBlock(in_channels, out_channels, io_kernel))
5656
self.add_module('core', core)
5757
self.add_module('out', OutputBlock(out_channels, out_spec, io_kernel))
58+
if cropsz is not None:
59+
self.add_module('crop', Crop(cropsz))

deepem/models/updown_deprecated.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import emvision
55
from emvision.models.layers import BilinearUp
6-
from deepem.models.layers import Conv
6+
from deepem.models.layers import Conv, Crop
77

88

99
def create_model(opt):
@@ -19,7 +19,7 @@ def create_model(opt):
1919
else:
2020
# Batch (instance) normalization
2121
core = emvision.models.RSUNet(width=width[:depth])
22-
return Model(core, opt.in_spec, opt.out_spec, width[0])
22+
return Model(core, opt.in_spec, opt.out_spec, width[0], cropsz=opt.cropsz)
2323

2424

2525
class InputBlock(nn.Sequential):
@@ -47,7 +47,7 @@ class Model(nn.Sequential):
4747
"""
4848
Residual Symmetric U-Net with down/upsampling in/output.
4949
"""
50-
def __init__(self, core, in_spec, out_spec, out_channels):
50+
def __init__(self, core, in_spec, out_spec, out_channels, cropsz=None):
5151
super(Model, self).__init__()
5252

5353
assert len(in_spec)==1, "model takes a single input"
@@ -57,3 +57,5 @@ def __init__(self, core, in_spec, out_spec, out_channels):
5757
self.add_module('in', InputBlock(in_channels, out_channels, io_kernel))
5858
self.add_module('core', core)
5959
self.add_module('out', OutputBlock(out_channels, out_spec, io_kernel))
60+
if cropsz is not None:
61+
self.add_module('crop', Crop(cropsz))

deepem/test/model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def __init__(self, model, opt):
1818
self.in_spec = dict(opt.in_spec)
1919
self.scan_spec = dict(opt.scan_spec)
2020
self.pretrain = opt.pretrain
21-
self.cropsz = opt.cropsz
21+
self.force_crop = opt.force_crop
2222

2323
# Softer softmax
2424
if opt.temperature is None:
@@ -62,8 +62,8 @@ def forward(self, sample):
6262
outputs[k] *= self.mask[k]
6363

6464
# Crop outputs.
65-
if self.cropsz is not None:
66-
outputs[k] = torch_utils.crop_border(outputs[k], self.cropsz)
65+
if self.force_crop is not None:
66+
outputs[k] = torch_utils.crop_border(outputs[k], self.force_crop)
6767

6868
return outputs
6969

deepem/test/option.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import json
33
import math
44
import os
5+
import numpy as np
56

67
from deepem.utils.py_utils import vec3, vec3f
78

@@ -28,7 +29,7 @@ def initialize(self):
2829
self.parser.add_argument('--no_eval', action='store_true')
2930
self.parser.add_argument('--inputsz', type=vec3, default=None)
3031
self.parser.add_argument('--outputsz', type=vec3, default=None)
31-
self.parser.add_argument('--cropsz', type=vec3, default=None)
32+
self.parser.add_argument('--force_crop', type=vec3, default=None)
3233
self.parser.add_argument('--fov', type=vec3, default=None)
3334
self.parser.add_argument('--depth', type=int, default=4)
3435
self.parser.add_argument('--width', type=int, default=None, nargs='+')
@@ -115,9 +116,17 @@ def parse(self):
115116
opt.fov = tuple(opt.fov)
116117
opt.inputsz = opt.fov if opt.inputsz is None else opt.inputsz
117118
opt.outputsz = opt.fov if opt.outputsz is None else opt.outputsz
118-
opt.cropsz = opt.cropsz
119119
opt.in_spec = dict(input=(1,) + opt.inputsz)
120120
opt.out_spec = dict()
121+
122+
# Crop output
123+
diff = np.array(opt.fov) - np.array(opt.outputsz)
124+
assert all(diff >= 0)
125+
if any(diff > 0):
126+
opt.cropsz = opt.outputsz
127+
else:
128+
opt.cropsz = None
129+
121130
if opt.aff:
122131
opt.out_spec['affinity'] = (3,) + opt.outputsz
123132
if opt.bdr:
@@ -154,6 +163,8 @@ def parse(self):
154163
opt.scan_spec['blood_vessel'] = (opt.blv_num_channels,) + opt.outputsz
155164
if opt.glia:
156165
opt.scan_spec['glia'] = (1,) + opt.outputsz
166+
167+
# Overlap & stride
157168
opt.overlap = self.get_overlap(opt.outputsz, opt.overlap)
158169
opt.stride = tuple(int(f-o) for f,o in zip(opt.outputsz, opt.overlap))
159170
opt.scan_params = dict(stride=opt.stride, blend=opt.blend)

deepem/train/data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def build(self, opt, data, is_train, prob):
5959
self.modifier = lambda x: x
6060
if opt.modifier is not None:
6161
mod = imp.load_source('modifier', opt.modifier)
62-
self.modifier = mod.Modifier()
62+
self.modifier = mod.Modifier(**opt.mod_params)
6363

6464
# Data loader
6565
size = (opt.max_iter - opt.chkpt_num) * opt.batch_size

0 commit comments

Comments
 (0)