Skip to content

Commit 5025dc3

Browse files
fyangf authored and tensorflower-gardener committed
No public description
PiperOrigin-RevId: 572370412
1 parent e5239d5 commit 5025dc3

File tree

3 files changed

+192
-5
lines changed

3 files changed

+192
-5
lines changed

official/vision/configs/semantic_segmentation.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,7 @@ class SemanticSegmentationTask(cfg.TaskConfig):
198198
init_checkpoint_modules: Union[
199199
str, List[str]] = 'all' # all, backbone, and/or decoder
200200
export_config: ExportConfig = dataclasses.field(default_factory=ExportConfig)
201+
allow_image_summary: bool = True
201202

202203

203204
@exp_factory.register_config_factory('semantic_segmentation')

official/vision/tasks/semantic_segmentation.py

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from official.vision.evaluation import segmentation_metrics
3030
from official.vision.losses import segmentation_losses
3131
from official.vision.modeling import factory
32+
from official.vision.utils.object_detection import visualization_utils
3233

3334

3435
@task_factory.register_task_cls(exp_cfg.SemanticSegmentationTask)
@@ -321,6 +322,14 @@ def validation_step(self,
321322
if metrics:
322323
self.process_metrics(metrics, labels, outputs)
323324

325+
if (
326+
hasattr(self.task_config, 'allow_image_summary')
327+
and self.task_config.allow_image_summary
328+
):
329+
logs.update(
330+
{'visualization': (tf.cast(features, dtype=tf.float32), outputs)}
331+
)
332+
324333
return logs
325334

326335
def inference_step(self, inputs: tf.Tensor, model: tf.keras.Model):
@@ -330,17 +339,37 @@ def inference_step(self, inputs: tf.Tensor, model: tf.keras.Model):
330339
def aggregate_logs(self, state=None, step_outputs=None):
  """Aggregates per-step validation outputs into a running state.

  Args:
    state: The state returned by the previous call, or None on the first call
      of an evaluation round.
    step_outputs: Optional mapping of outputs from one validation step. When
      it has a 'visualization' entry, the segmentation state built from it is
      merged into the returned state.

  Returns:
    A dict carrying the visualization state when visualization artifacts are
    present; otherwise the incoming state, or `True` when there was no
    incoming state.
  """
  if state is None and self.iou_metric is not None:
    # First call of the round: start the IoU metric from a clean slate.
    self.iou_metric.reset_states()

  # `step_outputs` defaults to None, so guard before the membership test.
  if step_outputs is not None and 'visualization' in step_outputs:
    # Update segmentation state for writing summary if there are artifacts for
    # visualization. A non-dict state (None or the `True` sentinel below)
    # cannot be updated in place, so start a fresh dict in that case.
    if not isinstance(state, dict):
      state = {}
    state.update(visualization_utils.update_segmentation_state(step_outputs))

  if state is None:
    # Create an arbitrary state to indicate it's not the first step in the
    # following calls to this function.
    state = True

  return state
335356

336357
def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
  """Reduces the aggregated validation state into final logs.

  Args:
    aggregated_logs: The state produced by `aggregate_logs`. Visualization is
      added only when this is a dict that has an 'image' entry.
    global_step: Unused by this method.

  Returns:
    A dict with 'mean_iou' (and 'iou/<i>' per class when per-class reporting
    is enabled in the task config) if an IoU metric is set, plus any
    visualization entries.
  """
  reduced = {}
  metric = self.iou_metric
  if metric is not None:
    ious = metric.result()
    # TODO(arashwan): support loading class name from a label map file.
    if self.task_config.evaluation.report_per_class_iou:
      reduced.update({
          'iou/{}'.format(class_idx): class_iou
          for class_idx, class_iou in enumerate(ious.numpy())
      })
    # Computes mean IoU
    reduced['mean_iou'] = tf.reduce_mean(ious)

  # Add visualization for summary.
  has_images = isinstance(aggregated_logs, dict) and 'image' in aggregated_logs
  if has_images:
    reduced.update(
        visualization_utils.visualize_segmentation_outputs(
            logs=aggregated_logs, task_config=self.task_config
        )
    )

  return reduced

official/vision/utils/object_detection/visualization_utils.py

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -894,3 +894,160 @@ def update_detection_state(step_outputs=None) -> Dict[str, Any]:
894894
state['detection_masks'] = tf.concat(detection_masks, axis=0)
895895

896896
return state
897+
898+
899+
def update_segmentation_state(step_outputs=None) -> Dict[str, Any]:
  """Builds a segmentation state holding input images and predicted logits.

  Args:
    step_outputs: Optional mapping of step outputs. Its 'visualization' entry,
      when present, is indexed as `[0]` for the images and `[1]['logits']` for
      the predictions.

  Returns:
    A dict with 'image' and 'logits', each concatenated along axis 0, or an
    empty dict when `step_outputs` is empty/None or has no 'visualization'
    entry.
  """
  state = {}
  # Visualization artifacts are optional: return an empty state instead of
  # raising KeyError when the step produced none.
  if not step_outputs or 'visualization' not in step_outputs:
    return state

  visualization = step_outputs['visualization']
  state['image'] = tf.concat(visualization[0], axis=0)
  state['logits'] = tf.concat(visualization[1]['logits'], axis=0)
  return state
908+
909+
910+
def visualize_segmentation_outputs(
    logs,
    task_config,
    original_image_spatial_shape=None,
    true_image_shape=None,
    image_mean: Optional[Union[float, List[float]]] = None,
    image_std: Optional[Union[float, List[float]]] = None,
    key: str = 'image/validation_outputs',
) -> Dict[str, Any]:
  """Visualizes the semantic segmentation outputs.

  It extracts images and logits from logs, takes the argmax of the logits over
  the last axis as the per-pixel class mask, de-normalizes the images and
  draws a color-coded mask overlay on each of them.

  Args:
    logs: A dictionary of logs that contains the input images under 'image'
      and the predictions under 'logits'.
    task_config: A task config. `task_config.model.num_classes` is read.
    original_image_spatial_shape: A [N, 2] tensor containing the spatial size of
      the original image.
    true_image_shape: A [N, 3] tensor containing the spatial size of unpadded
      original_image.
    image_mean: An optional float or list of floats used as the mean pixel value
      to normalize images.
    image_std: An optional float or list of floats used as the std to normalize
      images.
    key: A string specifying the key prefix of the returned dictionary.

  Returns:
    A dictionary with one entry per image, keyed `<key>/<i>`. Each value is the
    image with segments drawn on it, with a leading axis of size 1 added.

  Raises:
    ValueError: If only one of `image_mean` and `image_std` is set, or if they
      are neither both floats nor both lists.
  """
  images = logs['image']
  # Per-pixel class ids: the highest-scoring class along the last logits axis.
  masks = np.argmax(logs['logits'], axis=-1)
  num_classes = task_config.model.num_classes

  def _denormalize_images(images: tf.Tensor) -> tf.Tensor:
    """Undoes mean/std normalization on one image and casts it to uint8."""
    if image_mean is None and image_std is None:
      # No statistics given: fall back to the preprocessing defaults.
      images *= tf.constant(
          preprocess_ops.STDDEV_RGB, shape=[1, 1, 3], dtype=images.dtype
      )
      images += tf.constant(
          preprocess_ops.MEAN_RGB, shape=[1, 1, 3], dtype=images.dtype
      )
    elif image_mean is not None and image_std is not None:
      if isinstance(image_mean, float) and isinstance(image_std, float):
        images = images * image_std + image_mean
      elif isinstance(image_mean, list) and isinstance(image_std, list):
        images *= tf.constant(image_std, shape=[1, 1, 3], dtype=images.dtype)
        images += tf.constant(image_mean, shape=[1, 1, 3], dtype=images.dtype)
      else:
        raise ValueError(
            '`image_mean` and `image_std` should be the same type.'
        )
    else:
      raise ValueError(
          'Both `image_mean` and `image_std` should be set or None at the same '
          'time.'
      )
    return tf.cast(images, dtype=tf.uint8)

  images = tf.nest.map_structure(
      tf.identity,
      tf.map_fn(
          _denormalize_images,
          elems=images,
          fn_output_signature=tf.TensorSpec(
              shape=images.shape.as_list()[1:], dtype=tf.uint8
          ),
          parallel_iterations=32,
      ),
  )

  # Bring every image to exactly three channels before drawing.
  if images.shape[3] > 3:
    images = images[:, :, :, 0:3]
  elif images.shape[3] == 1:
    images = tf.image.grayscale_to_rgb(images)

  # tf.map_fn needs a value per image for every element, so fill the shape
  # inputs with -1 placeholders when they are not provided. The placeholders
  # are never read by `draw_segments`.
  batch_size = images.shape.as_list()[0]
  if true_image_shape is None:
    true_shapes = tf.constant(-1, shape=[batch_size, 3])
  else:
    true_shapes = true_image_shape
  if original_image_spatial_shape is None:
    original_shapes = tf.constant(-1, shape=[batch_size, 2])
  else:
    original_shapes = original_image_spatial_shape

  visualize_fn = functools.partial(_visualize_masks, num_classes=num_classes)
  elems = [true_shapes, original_shapes, images, masks]

  def draw_segments(image_and_segments):
    """Draws segments on one image."""
    image = image_and_segments[2]
    mask = image_and_segments[3]
    if original_image_spatial_shape is not None:
      # Clipping to the true shape is only useful as a pre-step of the resize.
      # Doing it conditionally here (instead of requiring it) avoids the
      # unbound `image` the previous code hit when only the original shape was
      # given, and leaves the image untouched when only the true shape is.
      if true_image_shape is not None:
        true_shape = image_and_segments[0]
        image = shape_utils.pad_or_clip_nd(
            image, [true_shape[0], true_shape[1], 3]
        )
      image = _resize_original_image(image, image_and_segments[1])

    # The drawing routine works on numpy arrays, so run it through py_func.
    return tf.compat.v1.py_func(visualize_fn, [image, mask], tf.uint8)

  images_with_segments = tf.map_fn(
      draw_segments, elems, dtype=tf.uint8, back_prop=False
  )

  outputs = {}
  for i, image in enumerate(images_with_segments):
    outputs[key + f'/{i}'] = image[None, ...]

  return outputs
1025+
1026+
1027+
def _visualize_masks(image, mask, num_classes, alpha=0.4):
  """Blends a color-coded semantic segmentation mask onto `image` in place.

  Args:
    image: A uint8 numpy image array; it is overwritten with the blended
      result.
    mask: A numpy array of per-pixel class ids. Pixels whose id is outside
      `[0, num_classes)` get a black overlay.
    num_classes: Number of classes to color. Colors come from
      `STANDARD_COLORS`, wrapping around when there are more classes than
      colors.
    alpha: Weight of the color overlay; it is scaled to a constant 0-255 mask
      used for compositing.

  Returns:
    The same `image` array, after the overlay has been copied into it.
  """
  # Start from an all-zero 3-channel canvas shaped like the mask.
  blank = np.expand_dims(np.zeros_like(mask), axis=2)
  overlay = np.repeat(blank, 3, axis=2)

  palette_size = len(STANDARD_COLORS)
  for class_id in range(num_classes):
    class_rgb = ImageColor.getrgb(STANDARD_COLORS[class_id % palette_size])
    class_pixels = np.expand_dims(np.where(mask == class_id, 1, 0), axis=2)
    overlay = overlay + class_pixels * np.reshape(list(class_rgb), [1, 1, 3])

  base = Image.fromarray(image)
  target_size = base.size
  # The mask may have a different resolution than the image, so both the color
  # layer and the blend mask are resized to the image size.
  color_layer = (
      Image.fromarray(np.uint8(overlay)).convert('RGBA').resize(target_size)
  )
  blend_mask = (
      Image.fromarray(np.uint8(255.0 * alpha * np.ones_like(mask)))
      .convert('L')
      .resize(target_size)
  )
  blended = Image.composite(color_layer, base, blend_mask)
  np.copyto(image, np.array(blended.convert('RGB')))
  return image

0 commit comments

Comments
 (0)