|
| 1 | +from augmentor import Augment |
| 2 | +from dataprovider3 import DataProvider, Dataset |
| 3 | + |
| 4 | + |
| 5 | +def get_spec(in_spec, out_spec): |
| 6 | + spec = dict() |
| 7 | + # Input spec |
| 8 | + for k, v in in_spec.items(): |
| 9 | + spec[k] = v[-3:] |
| 10 | + # Output spec |
| 11 | + for k, v in out_spec.items(): |
| 12 | + dim = tuple(v[-3:]) |
| 13 | + spec[k] = dim |
| 14 | + spec[k+'_mask'] = dim |
| 15 | + return spec |
| 16 | + |
| 17 | + |
| 18 | +class Sampler(object): |
| 19 | + def __init__(self, data, spec, is_train, aug=None, prob=None): |
| 20 | + self.is_train = is_train |
| 21 | + if 'long_range' in spec: |
| 22 | + self.long_range = True |
| 23 | + del spec['long_range'] |
| 24 | + del spec['long_range_mask'] |
| 25 | + else: |
| 26 | + self.long_range = False |
| 27 | + self.build(data, spec, aug, prob) |
| 28 | + |
| 29 | + def __call__(self): |
| 30 | + sample = self.dataprovider() |
| 31 | + return self.postprocess(sample) |
| 32 | + |
| 33 | + def postprocess(self, sample): |
| 34 | + assert 'affinity' in sample |
| 35 | + assert 'mitochondria' in sample |
| 36 | + |
| 37 | + # TODO: Copy or Ref? |
| 38 | + if self.long_range: |
| 39 | + sample['long_range'] = sample['affinity'] |
| 40 | + sample['long_range_mask'] = sample['affinity_mask'] |
| 41 | + |
| 42 | + sample = Augment.to_tensor(sample) |
| 43 | + return self.to_float32(sample) |
| 44 | + |
| 45 | + def to_float32(self, sample): |
| 46 | + for k, v in sample.items(): |
| 47 | + sample[k] = v.astype('float32') |
| 48 | + return sample |
| 49 | + |
| 50 | + def build(self, data, spec, aug, prob): |
| 51 | + dp = DataProvider(spec) |
| 52 | + keys = data.keys() |
| 53 | + for k in keys: |
| 54 | + dp.add_dataset(self.build_dataset(k, data[k])) |
| 55 | + dp.set_augment(aug) |
| 56 | + dp.set_imgs(['input']) |
| 57 | + dp.set_segs(['affinity']) |
| 58 | + if prob: |
| 59 | + dp.set_sampling_weights(p=[prob[k] for k in keys]) |
| 60 | + else: |
| 61 | + dp.set_sampling_weights(p=None) |
| 62 | + self.dataprovider = dp |
| 63 | + print(dp) |
| 64 | + |
| 65 | + def build_dataset(self, tag, data): |
| 66 | + img = data['img'] |
| 67 | + seg = data['seg'] |
| 68 | + mit = data['mit'] |
| 69 | + loc = data['loc'] |
| 70 | + msk = self.get_mask(data) |
| 71 | + |
| 72 | + # Create Dataset. |
| 73 | + dset = Dataset(tag=tag) |
| 74 | + dset.add_data(key='input', data=img) |
| 75 | + dset.add_data(key='affinity', data=seg) |
| 76 | + dset.add_mask(key='affinity_mask', data=msk, loc=loc) |
| 77 | + dset.add_data(key='mitochondria', data=mit) |
| 78 | + dset.add_mask(key='mitochondria_mask', data=msk) |
| 79 | + |
| 80 | + return dset |
| 81 | + |
| 82 | + def get_mask(self, data): |
| 83 | + key = 'msk_train' if self.is_train else 'msk_val' |
| 84 | + if key in data: |
| 85 | + return data[key] |
| 86 | + assert 'msk' in data |
| 87 | + return data['msk'] |
0 commit comments