@@ -894,3 +894,160 @@ def update_detection_state(step_outputs=None) -> Dict[str, Any]:
894
894
state ['detection_masks' ] = tf .concat (detection_masks , axis = 0 )
895
895
896
896
return state
897
+
898
+
899
def update_segmentation_state(step_outputs=None) -> Dict[str, Any]:
  """Builds the segmentation visualization state from one step's outputs.

  Args:
    step_outputs: Optional mapping produced by a step. When truthy, its
      `visualization` entry is indexed as `[0]` for the input images and
      `[1]['logits']` for the predicted logits.

  Returns:
    An empty dictionary when `step_outputs` is falsy; otherwise a dictionary
    with `image` and `logits`, each concatenated along axis 0.
  """
  if not step_outputs:
    return {}

  visualization = step_outputs['visualization']
  return {
      'image': tf.concat(visualization[0], axis=0),
      'logits': tf.concat(visualization[1]['logits'], axis=0),
  }
908
+
909
+
910
def visualize_segmentation_outputs(
    logs,
    task_config,
    original_image_spatial_shape=None,
    true_image_shape=None,
    image_mean: Optional[Union[float, List[float]]] = None,
    image_std: Optional[Union[float, List[float]]] = None,
    key: str = 'image/validation_outputs',
) -> Dict[str, Any]:
  """Visualizes the semantic segmentation outputs.

  It extracts the input images and the predicted logits from `logs`, turns the
  logits into per-pixel class ids with an argmax over the last axis, and
  overlays a color-coded class mask on every denormalized input image.

  Args:
    logs: A dictionary of logs that contains the input images under `image`
      and the segmentation logits under `logits`.
    task_config: A task config. `task_config.model.num_classes` is read to
      decide how many class ids get a color.
    original_image_spatial_shape: An optional [N, 2] tensor containing the
      spatial size of the original image. When set, every image is resized to
      that size before the mask is drawn.
    true_image_shape: An optional [N, 3] tensor containing the spatial size of
      the unpadded original_image. It is only used together with
      `original_image_spatial_shape`, to pad or clip the image before it is
      resized.
    image_mean: An optional number or list of numbers used as the mean pixel
      value to normalize images.
    image_std: An optional number or list of numbers used as the std to
      normalize images.
    key: A string used as the prefix of the keys of the returned dictionary.

  Returns:
    A dictionary of images with visualization drawn on them. Entry
    `{key}/{i}` holds the i-th image as a 4D tensor (leading axis of size 1)
    with the segmentation mask drawn on it.

  Raises:
    ValueError: If only one of `image_mean` and `image_std` is set, or if one
      is a scalar while the other is a sequence.
  """
  images = logs['image']
  masks = np.argmax(logs['logits'], axis=-1)
  num_classes = task_config.model.num_classes

  def _denormalize_images(images: tf.Tensor) -> tf.Tensor:
    """Undoes the input normalization of one image and casts it to uint8."""
    if image_mean is None and image_std is None:
      # No statistics given: fall back to the defaults in `preprocess_ops`.
      images *= tf.constant(
          preprocess_ops.STDDEV_RGB, shape=[1, 1, 3], dtype=images.dtype
      )
      images += tf.constant(
          preprocess_ops.MEAN_RGB, shape=[1, 1, 3], dtype=images.dtype
      )
    elif image_mean is not None and image_std is not None:
      # Plain ints and tuples are accepted as well, so that e.g.
      # `image_mean=0, image_std=255` is not rejected as a type mismatch.
      if isinstance(image_mean, (int, float)) and isinstance(
          image_std, (int, float)
      ):
        images = images * image_std + image_mean
      elif isinstance(image_mean, (list, tuple)) and isinstance(
          image_std, (list, tuple)
      ):
        images *= tf.constant(image_std, shape=[1, 1, 3], dtype=images.dtype)
        images += tf.constant(image_mean, shape=[1, 1, 3], dtype=images.dtype)
      else:
        raise ValueError(
            '`image_mean` and `image_std` should be the same type.'
        )
    else:
      raise ValueError(
          'Both `image_mean` and `image_std` should be set or None at the same '
          'time.'
      )
    return tf.cast(images, dtype=tf.uint8)

  images = tf.nest.map_structure(
      tf.identity,
      tf.map_fn(
          _denormalize_images,
          elems=images,
          fn_output_signature=tf.TensorSpec(
              shape=images.shape.as_list()[1:], dtype=tf.uint8
          ),
          parallel_iterations=32,
      ),
  )

  # The mask is drawn on a 3-channel image: drop extra channels or replicate
  # a single gray channel.
  if images.shape[3] > 3:
    images = images[:, :, :, 0:3]
  elif images.shape[3] == 1:
    images = tf.image.grayscale_to_rgb(images)

  # `tf.map_fn` needs one entry per image for every element, so missing shapes
  # are replaced by -1 placeholders that `draw_segments` never reads.
  batch_size = images.shape.as_list()[0]
  if true_image_shape is None:
    true_shapes = tf.constant(-1, shape=[batch_size, 3])
  else:
    true_shapes = true_image_shape
  if original_image_spatial_shape is None:
    original_shapes = tf.constant(-1, shape=[batch_size, 2])
  else:
    original_shapes = original_image_spatial_shape

  visualize_fn = functools.partial(_visualize_masks, num_classes=num_classes)
  elems = [true_shapes, original_shapes, images, masks]

  def draw_segments(image_and_segments):
    """Draws the segmentation mask on one image."""
    true_shape, original_shape, image, mask = image_and_segments
    if original_image_spatial_shape is not None:
      if true_image_shape is not None:
        # Pad or clip to the true shape first so that only that region is
        # resized to the original size.
        image = shape_utils.pad_or_clip_nd(
            image, [true_shape[0], true_shape[1], 3]
        )
      # Previously this branch read an unbound `image` (UnboundLocalError)
      # whenever `true_image_shape` was None; it now resizes the input image.
      image = _resize_original_image(image, original_shape)

    return tf.compat.v1.py_func(visualize_fn, [image, mask], tf.uint8)

  images_with_segments = tf.map_fn(
      draw_segments, elems, fn_output_signature=tf.uint8, back_prop=False
  )

  # Re-add a leading axis of size 1 so every entry is a 4D image tensor.
  return {
      f'{key}/{i}': image[None, ...]
      for i, image in enumerate(images_with_segments)
  }
1025
+
1026
+
1027
def _visualize_masks(image, mask, num_classes, alpha=0.4):
  """Visualizes semantic segmentation masks.

  Every class id in `[0, num_classes)` is mapped to an entry of
  `STANDARD_COLORS` (cycling when there are more classes than colors), and the
  resulting color map is blended over `image`.

  Args:
    image: A numpy array that `Image.fromarray` can load as an image. It is
      overwritten in place with the blended result.
    mask: A numpy array of per-pixel class ids. Ids outside
      `[0, num_classes)` are drawn in black.
    num_classes: The number of class ids that are assigned a color.
    alpha: The blending weight of the color map; the mask passed to
      `Image.composite` has the constant value `255 * alpha`.

  Returns:
    The same `image` array, after the overlay has been copied into it.
  """
  # One RGB row per class id plus a trailing black row for ids that are out
  # of range, so the color map is a single lookup instead of one full-image
  # pass per class.
  num_colors = max(int(num_classes), 0)
  palette = np.zeros((num_colors + 1, 3), dtype=np.uint8)
  for class_id in range(num_colors):
    color = STANDARD_COLORS[class_id % len(STANDARD_COLORS)]
    palette[class_id] = ImageColor.getrgb(color)

  class_ids = np.asarray(mask)
  lookup = class_ids.astype(np.intp)
  # Only exact integer ids inside [0, num_colors) select a class color, which
  # matches an equality test of the mask against every class id.
  in_range = (lookup == class_ids) & (lookup >= 0) & (lookup < num_colors)
  solid_color = palette[np.where(in_range, lookup, num_colors)]

  pil_image = Image.fromarray(image)
  # The color map and the blend mask are resized to the image size because
  # the mask resolution may differ from the image resolution.
  pil_solid_color = (
      Image.fromarray(solid_color).convert('RGBA').resize(pil_image.size)
  )
  pil_mask = (
      Image.fromarray(np.uint8(255.0 * alpha * np.ones_like(mask)))
      .convert('L')
      .resize(pil_image.size)
  )
  pil_image = Image.composite(pil_solid_color, pil_image, pil_mask)
  # Write the result back into the caller's buffer instead of allocating a
  # new array.
  np.copyto(image, np.array(pil_image.convert('RGB')))
  return image
0 commit comments