Skip to content

Commit d68ec87

Browse files
committed
affinity & mitochondria smapler
1 parent 2cfbd28 commit d68ec87

File tree

1 file changed

+87
-0
lines changed

1 file changed

+87
-0
lines changed

deepem/data/sampler/aff_mit.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
from augmentor import Augment
2+
from dataprovider3 import DataProvider, Dataset
3+
4+
5+
def get_spec(in_spec, out_spec):
6+
spec = dict()
7+
# Input spec
8+
for k, v in in_spec.items():
9+
spec[k] = v[-3:]
10+
# Output spec
11+
for k, v in out_spec.items():
12+
dim = tuple(v[-3:])
13+
spec[k] = dim
14+
spec[k+'_mask'] = dim
15+
return spec
16+
17+
18+
class Sampler(object):
19+
def __init__(self, data, spec, is_train, aug=None, prob=None):
20+
self.is_train = is_train
21+
if 'long_range' in spec:
22+
self.long_range = True
23+
del spec['long_range']
24+
del spec['long_range_mask']
25+
else:
26+
self.long_range = False
27+
self.build(data, spec, aug, prob)
28+
29+
def __call__(self):
30+
sample = self.dataprovider()
31+
return self.postprocess(sample)
32+
33+
def postprocess(self, sample):
34+
assert 'affinity' in sample
35+
assert 'mitochondria' in sample
36+
37+
# TODO: Copy or Ref?
38+
if self.long_range:
39+
sample['long_range'] = sample['affinity']
40+
sample['long_range_mask'] = sample['affinity_mask']
41+
42+
sample = Augment.to_tensor(sample)
43+
return self.to_float32(sample)
44+
45+
def to_float32(self, sample):
46+
for k, v in sample.items():
47+
sample[k] = v.astype('float32')
48+
return sample
49+
50+
def build(self, data, spec, aug, prob):
51+
dp = DataProvider(spec)
52+
keys = data.keys()
53+
for k in keys:
54+
dp.add_dataset(self.build_dataset(k, data[k]))
55+
dp.set_augment(aug)
56+
dp.set_imgs(['input'])
57+
dp.set_segs(['affinity'])
58+
if prob:
59+
dp.set_sampling_weights(p=[prob[k] for k in keys])
60+
else:
61+
dp.set_sampling_weights(p=None)
62+
self.dataprovider = dp
63+
print(dp)
64+
65+
def build_dataset(self, tag, data):
66+
img = data['img']
67+
seg = data['seg']
68+
mit = data['mit']
69+
loc = data['loc']
70+
msk = self.get_mask(data)
71+
72+
# Create Dataset.
73+
dset = Dataset(tag=tag)
74+
dset.add_data(key='input', data=img)
75+
dset.add_data(key='affinity', data=seg)
76+
dset.add_mask(key='affinity_mask', data=msk, loc=loc)
77+
dset.add_data(key='mitochondria', data=mit)
78+
dset.add_mask(key='mitochondria_mask', data=msk)
79+
80+
return dset
81+
82+
def get_mask(self, data):
83+
key = 'msk_train' if self.is_train else 'msk_val'
84+
if key in data:
85+
return data[key]
86+
assert 'msk' in data
87+
return data['msk']

0 commit comments

Comments
 (0)