19
19
"""
20
20
import collections
21
21
import functools
22
- from typing import Any , Dict
22
+ from typing import Any , Dict , Optional , List , Union
23
23
24
24
from absl import logging
25
25
# Set headless-friendly backend.
@@ -354,7 +354,10 @@ def visualize_outputs(
354
354
max_boxes_to_draw = 20 ,
355
355
min_score_thresh = 0.2 ,
356
356
use_normalized_coordinates = False ,
357
- ):
357
+ image_mean : Optional [Union [float , List [float ]]] = None ,
358
+ image_std : Optional [Union [float , List [float ]]] = None ,
359
+ key : str = 'image/validation_outputs' ,
360
+ ) -> Dict [str , Any ]:
358
361
"""Visualizes the detection outputs.
359
362
360
363
It extracts images and predictions from logs and draws visualization on input
@@ -375,11 +378,17 @@ def visualize_outputs(
375
378
0.2.
376
379
use_normalized_coordinates: Whether to assume boxes and kepoints are in
377
380
normalized coordinates (as opposed to absolute coordiantes). Default is
378
- True.
381
+ False.
382
+ image_mean: An optional float or list of floats used as the mean pixel value
383
+ to normalize images.
384
+ image_std: An optional float or list of floats used as the std to normalize
385
+ images.
386
+ key: A string specifying the key of the returned dictionary.
379
387
380
388
Returns:
381
- A 4D tensor with predictions (boxes, segments and/or keypoints) drawn on
382
- each image.
389
+ A dictionary of images with visualization drawn on it. Each key corresponds
390
+ to a 4D tensor with predictions (boxes, segments and/or keypoints) drawn
391
+ on each image.
383
392
"""
384
393
images = logs ['image' ]
385
394
boxes = logs ['detection_boxes' ]
@@ -399,12 +408,28 @@ def visualize_outputs(
399
408
category_index [i ] = {'id' : i , 'name' : str (i )}
400
409
401
410
def _denormalize_images (images : tf .Tensor ) -> tf .Tensor :
402
- images *= tf .constant (
403
- preprocess_ops .STDDEV_RGB , shape = [1 , 1 , 3 ], dtype = images .dtype
404
- )
405
- images += tf .constant (
406
- preprocess_ops .MEAN_RGB , shape = [1 , 1 , 3 ], dtype = images .dtype
407
- )
411
+ if image_mean is None and image_std is None :
412
+ images *= tf .constant (
413
+ preprocess_ops .STDDEV_RGB , shape = [1 , 1 , 3 ], dtype = images .dtype
414
+ )
415
+ images += tf .constant (
416
+ preprocess_ops .MEAN_RGB , shape = [1 , 1 , 3 ], dtype = images .dtype
417
+ )
418
+ elif image_mean is not None and image_std is not None :
419
+ if isinstance (image_mean , float ) and isinstance (image_std , float ):
420
+ images = images * image_std + image_mean
421
+ elif isinstance (image_mean , list ) and isinstance (image_std , list ):
422
+ images *= tf .constant (image_std , shape = [1 , 1 , 3 ], dtype = images .dtype )
423
+ images += tf .constant (image_mean , shape = [1 , 1 , 3 ], dtype = images .dtype )
424
+ else :
425
+ raise ValueError (
426
+ '`image_mean` and `image_std` should be the same type.'
427
+ )
428
+ else :
429
+ raise ValueError (
430
+ 'Both `image_mean` and `image_std` should be set or None at the same '
431
+ 'time.'
432
+ )
408
433
return tf .cast (images , dtype = tf .uint8 )
409
434
410
435
images = tf .nest .map_structure (
@@ -419,7 +444,7 @@ def _denormalize_images(images: tf.Tensor) -> tf.Tensor:
419
444
),
420
445
)
421
446
422
- return draw_bounding_boxes_on_image_tensors (
447
+ images_with_boxes = draw_bounding_boxes_on_image_tensors (
423
448
images ,
424
449
boxes ,
425
450
classes ,
@@ -434,6 +459,12 @@ def _denormalize_images(images: tf.Tensor) -> tf.Tensor:
434
459
use_normalized_coordinates ,
435
460
)
436
461
462
+ outputs = {}
463
+ for i , image in enumerate (images_with_boxes ):
464
+ outputs [key + f'/{ i } ' ] = image [None , ...]
465
+
466
+ return outputs
467
+
437
468
438
469
def draw_bounding_boxes_on_image_tensors (images ,
439
470
boxes ,
0 commit comments