Skip to content

Commit 0182948

Browse files
fyangf authored and tensorflower-gardener committed
Internal change
PiperOrigin-RevId: 519765233
1 parent 67f4fc6 commit 0182948

File tree

2 files changed

+44
-13
lines changed

2 files changed

+44
-13
lines changed

official/vision/tasks/retinanet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -456,6 +456,6 @@ def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
456456
validation_outputs = visualization_utils.visualize_outputs(
457457
logs=aggregated_logs, task_config=self.task_config
458458
)
459-
logs.update({'image/validation_outputs': validation_outputs})
459+
logs.update(validation_outputs)
460460

461461
return logs

official/vision/utils/object_detection/visualization_utils.py

Lines changed: 43 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
"""
2020
import collections
2121
import functools
22-
from typing import Any, Dict
22+
from typing import Any, Dict, Optional, List, Union
2323

2424
from absl import logging
2525
# Set headless-friendly backend.
@@ -354,7 +354,10 @@ def visualize_outputs(
354354
max_boxes_to_draw=20,
355355
min_score_thresh=0.2,
356356
use_normalized_coordinates=False,
357-
):
357+
image_mean: Optional[Union[float, List[float]]] = None,
358+
image_std: Optional[Union[float, List[float]]] = None,
359+
key: str = 'image/validation_outputs',
360+
) -> Dict[str, Any]:
358361
"""Visualizes the detection outputs.
359362
360363
It extracts images and predictions from logs and draws visualization on input
@@ -375,11 +378,17 @@ def visualize_outputs(
375378
0.2.
376379
use_normalized_coordinates: Whether to assume boxes and keypoints are in
377380
normalized coordinates (as opposed to absolute coordinates). Default is
378-
True.
381+
False.
382+
image_mean: An optional float or list of floats used as the mean pixel value
383+
to normalize images.
384+
image_std: An optional float or list of floats used as the std to normalize
385+
images.
386+
key: A string specifying the key of the returned dictionary.
379387
380388
Returns:
381-
A 4D tensor with predictions (boxes, segments and/or keypoints) drawn on
382-
each image.
389+
A dictionary of images with visualizations drawn on them. Each key corresponds
390+
to a 4D tensor with predictions (boxes, segments and/or keypoints) drawn
391+
on each image.
383392
"""
384393
images = logs['image']
385394
boxes = logs['detection_boxes']
@@ -399,12 +408,28 @@ def visualize_outputs(
399408
category_index[i] = {'id': i, 'name': str(i)}
400409

401410
def _denormalize_images(images: tf.Tensor) -> tf.Tensor:
402-
images *= tf.constant(
403-
preprocess_ops.STDDEV_RGB, shape=[1, 1, 3], dtype=images.dtype
404-
)
405-
images += tf.constant(
406-
preprocess_ops.MEAN_RGB, shape=[1, 1, 3], dtype=images.dtype
407-
)
411+
if image_mean is None and image_std is None:
412+
images *= tf.constant(
413+
preprocess_ops.STDDEV_RGB, shape=[1, 1, 3], dtype=images.dtype
414+
)
415+
images += tf.constant(
416+
preprocess_ops.MEAN_RGB, shape=[1, 1, 3], dtype=images.dtype
417+
)
418+
elif image_mean is not None and image_std is not None:
419+
if isinstance(image_mean, float) and isinstance(image_std, float):
420+
images = images * image_std + image_mean
421+
elif isinstance(image_mean, list) and isinstance(image_std, list):
422+
images *= tf.constant(image_std, shape=[1, 1, 3], dtype=images.dtype)
423+
images += tf.constant(image_mean, shape=[1, 1, 3], dtype=images.dtype)
424+
else:
425+
raise ValueError(
426+
'`image_mean` and `image_std` should be the same type.'
427+
)
428+
else:
429+
raise ValueError(
430+
'Both `image_mean` and `image_std` should be set or None at the same '
431+
'time.'
432+
)
408433
return tf.cast(images, dtype=tf.uint8)
409434

410435
images = tf.nest.map_structure(
@@ -419,7 +444,7 @@ def _denormalize_images(images: tf.Tensor) -> tf.Tensor:
419444
),
420445
)
421446

422-
return draw_bounding_boxes_on_image_tensors(
447+
images_with_boxes = draw_bounding_boxes_on_image_tensors(
423448
images,
424449
boxes,
425450
classes,
@@ -434,6 +459,12 @@ def _denormalize_images(images: tf.Tensor) -> tf.Tensor:
434459
use_normalized_coordinates,
435460
)
436461

462+
outputs = {}
463+
for i, image in enumerate(images_with_boxes):
464+
outputs[key + f'/{i}'] = image[None, ...]
465+
466+
return outputs
467+
437468

438469
def draw_bounding_boxes_on_image_tensors(images,
439470
boxes,

0 commit comments

Comments
 (0)