Skip to content

Commit 721a6fa

Browse files
Internal change
PiperOrigin-RevId: 480986311
1 parent c60499b commit 721a6fa

File tree

1 file changed

+28
-12
lines changed

1 file changed

+28
-12
lines changed

official/projects/panoptic/dataloaders/panoptic_maskrcnn_input.py

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -48,18 +48,24 @@ def __init__(
4848
panoptic_category_mask_key:
4949
tf.io.FixedLenFeature((), tf.string, default_value=''),
5050
panoptic_instance_mask_key:
51-
tf.io.FixedLenFeature((), tf.string, default_value='')})
51+
tf.io.FixedLenFeature((), tf.string, default_value='')
52+
})
5253
self._segmentation_keys_to_features = keys_to_features
5354

55+
def decode_segmentation_mask(self, parsed_tensors):
56+
segmentation_mask = tf.io.decode_image(
57+
parsed_tensors['image/segmentation/class/encoded'], channels=1)
58+
segmentation_mask.set_shape([None, None, 1])
59+
return segmentation_mask
60+
5461
def decode(self, serialized_example):
5562
decoded_tensors = super(TfExampleDecoder, self).decode(serialized_example)
5663
parsed_tensors = tf.io.parse_single_example(
5764
serialized_example, self._segmentation_keys_to_features)
58-
segmentation_mask = tf.io.decode_image(
59-
parsed_tensors['image/segmentation/class/encoded'],
60-
channels=1)
61-
segmentation_mask.set_shape([None, None, 1])
62-
decoded_tensors.update({'groundtruth_segmentation_mask': segmentation_mask})
65+
decoded_tensors.update({
66+
'groundtruth_segmentation_mask':
67+
self.decode_segmentation_mask(parsed_tensors)
68+
})
6369

6470
if self._include_panoptic_masks:
6571
category_mask = tf.io.decode_image(
@@ -221,18 +227,21 @@ def _parse_train_data(self, data):
221227
are supposed to be used in computing the segmentation loss while
222228
training.
223229
"""
230+
# (height, width, num_channels = 1)
231+
# All the operations below support num_channels >= 1.
224232
segmentation_mask = data['groundtruth_segmentation_mask']
225233

226234
# Flips image randomly during training.
227235
if self.aug_rand_hflip:
228236
masks = data['groundtruth_instance_masks']
237+
num_image_channels = data['image'].shape.as_list()[-1]
229238
image_mask = tf.concat([data['image'], segmentation_mask], axis=2)
230239

231240
image_mask, boxes, masks = preprocess_ops.random_horizontal_flip(
232241
image_mask, data['groundtruth_boxes'], masks)
233242

234-
segmentation_mask = image_mask[:, :, -1:]
235-
image = image_mask[:, :, :-1]
243+
image = image_mask[:, :, :num_image_channels]
244+
segmentation_mask = image_mask[:, :, num_image_channels:]
236245

237246
data['image'] = image
238247
data['groundtruth_boxes'] = boxes
@@ -244,21 +253,22 @@ def _parse_train_data(self, data):
244253
image_scale = image_info[2, :]
245254
offset = image_info[3, :]
246255

247-
segmentation_mask = tf.reshape(
248-
segmentation_mask, shape=[1, data['height'], data['width']])
256+
# (height, width, num_channels = 1)
249257
segmentation_mask = tf.cast(segmentation_mask, tf.float32)
250258

251259
# Pad label and make sure the padded region assigned to the ignore label.
252260
# The label is first offset by +1 and then padded with 0.
253261
segmentation_mask += 1
254-
segmentation_mask = tf.expand_dims(segmentation_mask, axis=3)
262+
# (1, height, width, num_channels = 1)
263+
segmentation_mask = tf.expand_dims(segmentation_mask, axis=0)
255264
segmentation_mask = preprocess_ops.resize_and_crop_masks(
256265
segmentation_mask, image_scale, self._output_size, offset)
257266
segmentation_mask -= 1
258267
segmentation_mask = tf.where(
259268
tf.equal(segmentation_mask, -1),
260269
self._segmentation_ignore_label * tf.ones_like(segmentation_mask),
261270
segmentation_mask)
271+
# (height, width, num_channels = 1)
262272
segmentation_mask = tf.squeeze(segmentation_mask, axis=0)
263273
segmentation_valid_mask = tf.not_equal(
264274
segmentation_mask, self._segmentation_ignore_label)
@@ -291,9 +301,13 @@ def _parse_eval_data(self, data):
291301
shape [height_l, width_l, 4] representing anchor boxes at each
292302
level.
293303
"""
304+
294305
def _process_mask(mask, ignore_label, image_info):
306+
# (height, width, num_channels = 1)
307+
# All the operations below support num_channels >= 1.
295308
mask = tf.cast(mask, dtype=tf.float32)
296-
mask = tf.reshape(mask, shape=[1, data['height'], data['width'], 1])
309+
# (1, height, width, num_channels = 1)
310+
mask = tf.expand_dims(mask, axis=0)
297311
mask += 1
298312

299313
if self._segmentation_resize_eval_groundtruth:
@@ -314,12 +328,14 @@ def _process_mask(mask, ignore_label, image_info):
314328
tf.equal(mask, -1),
315329
ignore_label * tf.ones_like(mask),
316330
mask)
331+
# (height, width, num_channels = 1)
317332
mask = tf.squeeze(mask, axis=0)
318333
return mask
319334

320335
image, labels = super(Parser, self)._parse_eval_data(data)
321336
image_info = labels['image_info']
322337

338+
# (height, width, num_channels = 1)
323339
segmentation_mask = _process_mask(
324340
data['groundtruth_segmentation_mask'],
325341
self._segmentation_ignore_label, image_info)

0 commit comments

Comments
 (0)