@@ -48,18 +48,24 @@ def __init__(
48
48
panoptic_category_mask_key :
49
49
tf .io .FixedLenFeature ((), tf .string , default_value = '' ),
50
50
panoptic_instance_mask_key :
51
- tf .io .FixedLenFeature ((), tf .string , default_value = '' )})
51
+ tf .io .FixedLenFeature ((), tf .string , default_value = '' )
52
+ })
52
53
self ._segmentation_keys_to_features = keys_to_features
53
54
55
+ def decode_segmentation_mask (self , parsed_tensors ):
56
+ segmentation_mask = tf .io .decode_image (
57
+ parsed_tensors ['image/segmentation/class/encoded' ], channels = 1 )
58
+ segmentation_mask .set_shape ([None , None , 1 ])
59
+ return segmentation_mask
60
+
54
61
def decode (self , serialized_example ):
55
62
decoded_tensors = super (TfExampleDecoder , self ).decode (serialized_example )
56
63
parsed_tensors = tf .io .parse_single_example (
57
64
serialized_example , self ._segmentation_keys_to_features )
58
- segmentation_mask = tf .io .decode_image (
59
- parsed_tensors ['image/segmentation/class/encoded' ],
60
- channels = 1 )
61
- segmentation_mask .set_shape ([None , None , 1 ])
62
- decoded_tensors .update ({'groundtruth_segmentation_mask' : segmentation_mask })
65
+ decoded_tensors .update ({
66
+ 'groundtruth_segmentation_mask' :
67
+ self .decode_segmentation_mask (parsed_tensors )
68
+ })
63
69
64
70
if self ._include_panoptic_masks :
65
71
category_mask = tf .io .decode_image (
@@ -221,18 +227,21 @@ def _parse_train_data(self, data):
221
227
are supposed to be used in computing the segmentation loss while
222
228
training.
223
229
"""
230
+ # (height, width, num_channels = 1)
231
+ # All the operations below support num_channels >= 1.
224
232
segmentation_mask = data ['groundtruth_segmentation_mask' ]
225
233
226
234
# Flips image randomly during training.
227
235
if self .aug_rand_hflip :
228
236
masks = data ['groundtruth_instance_masks' ]
237
+ num_image_channels = data ['image' ].shape .as_list ()[- 1 ]
229
238
image_mask = tf .concat ([data ['image' ], segmentation_mask ], axis = 2 )
230
239
231
240
image_mask , boxes , masks = preprocess_ops .random_horizontal_flip (
232
241
image_mask , data ['groundtruth_boxes' ], masks )
233
242
234
- segmentation_mask = image_mask [:, :, - 1 : ]
235
- image = image_mask [:, :, : - 1 ]
243
+ image = image_mask [:, :, : num_image_channels ]
244
+ segmentation_mask = image_mask [:, :, num_image_channels : ]
236
245
237
246
data ['image' ] = image
238
247
data ['groundtruth_boxes' ] = boxes
@@ -244,21 +253,22 @@ def _parse_train_data(self, data):
244
253
image_scale = image_info [2 , :]
245
254
offset = image_info [3 , :]
246
255
247
- segmentation_mask = tf .reshape (
248
- segmentation_mask , shape = [1 , data ['height' ], data ['width' ]])
256
+ # (height, width, num_channels = 1)
249
257
segmentation_mask = tf .cast (segmentation_mask , tf .float32 )
250
258
251
259
# Pad label and make sure the padded region assigned to the ignore label.
252
260
# The label is first offset by +1 and then padded with 0.
253
261
segmentation_mask += 1
254
- segmentation_mask = tf .expand_dims (segmentation_mask , axis = 3 )
262
+ # (1, height, width, num_channels = 1)
263
+ segmentation_mask = tf .expand_dims (segmentation_mask , axis = 0 )
255
264
segmentation_mask = preprocess_ops .resize_and_crop_masks (
256
265
segmentation_mask , image_scale , self ._output_size , offset )
257
266
segmentation_mask -= 1
258
267
segmentation_mask = tf .where (
259
268
tf .equal (segmentation_mask , - 1 ),
260
269
self ._segmentation_ignore_label * tf .ones_like (segmentation_mask ),
261
270
segmentation_mask )
271
+ # (height, width, num_channels = 1)
262
272
segmentation_mask = tf .squeeze (segmentation_mask , axis = 0 )
263
273
segmentation_valid_mask = tf .not_equal (
264
274
segmentation_mask , self ._segmentation_ignore_label )
@@ -291,9 +301,13 @@ def _parse_eval_data(self, data):
291
301
shape [height_l, width_l, 4] representing anchor boxes at each
292
302
level.
293
303
"""
304
+
294
305
def _process_mask (mask , ignore_label , image_info ):
306
+ # (height, width, num_channels = 1)
307
+ # All the operations below support num_channels >= 1.
295
308
mask = tf .cast (mask , dtype = tf .float32 )
296
- mask = tf .reshape (mask , shape = [1 , data ['height' ], data ['width' ], 1 ])
309
+ # (1, height, width, num_channels = 1)
310
+ mask = tf .expand_dims (mask , axis = 0 )
297
311
mask += 1
298
312
299
313
if self ._segmentation_resize_eval_groundtruth :
@@ -314,12 +328,14 @@ def _process_mask(mask, ignore_label, image_info):
314
328
tf .equal (mask , - 1 ),
315
329
ignore_label * tf .ones_like (mask ),
316
330
mask )
331
+ # (height, width, num_channels = 1)
317
332
mask = tf .squeeze (mask , axis = 0 )
318
333
return mask
319
334
320
335
image , labels = super (Parser , self )._parse_eval_data (data )
321
336
image_info = labels ['image_info' ]
322
337
338
+ # (height, width, num_channels = 1)
323
339
segmentation_mask = _process_mask (
324
340
data ['groundtruth_segmentation_mask' ],
325
341
self ._segmentation_ignore_label , image_info )
0 commit comments