@@ -853,6 +853,7 @@ def __init__(self,
853
853
use_cpu_nms : bool = False ,
854
854
soft_nms_sigma : Optional [float ] = None ,
855
855
tflite_post_processing_config : Optional [Dict [str , Any ]] = None ,
856
+ pre_nms_top_k_sharding_block : Optional [int ] = None ,
856
857
** kwargs ):
857
858
"""Initializes a multi-level detection generator.
858
859
@@ -873,6 +874,9 @@ def __init__(self,
873
874
When soft_nms_sigma=0.0, we fall back to standard NMS.
874
875
tflite_post_processing_config: An optional dictionary containing
875
876
post-processing parameters used for TFLite custom NMS op.
877
+ pre_nms_top_k_sharding_block: For v3 (edge tpu friendly) NMS, avoids
878
+ creating long axis for pre_nms_top_k. Will do top_k in shards of size
879
+ [num_classes, pre_nms_top_k_sharding_block * boxes_per_location]
876
880
877
881
**kwargs: Additional keyword arguments passed to Layer.
878
882
"""
@@ -886,11 +890,15 @@ def __init__(self,
886
890
'use_cpu_nms' : use_cpu_nms ,
887
891
'soft_nms_sigma' : soft_nms_sigma
888
892
}
893
+ # Don't store if were not defined
894
+ if pre_nms_top_k_sharding_block is not None :
895
+ self ._config_dict [
896
+ 'pre_nms_top_k_sharding_block' ] = pre_nms_top_k_sharding_block
889
897
890
898
if tflite_post_processing_config is not None :
891
899
self ._config_dict .update (
892
900
{'tflite_post_processing_config' : tflite_post_processing_config })
893
- super (MultilevelDetectionGenerator , self ).__init__ (** kwargs )
901
+ super ().__init__ (** kwargs )
894
902
895
903
def _decode_multilevel_outputs (
896
904
self ,
@@ -969,12 +977,74 @@ def _decode_multilevel_outputs(
969
977
970
978
return boxes , scores , attributes
971
979
972
- def __call__ (self ,
973
- raw_boxes : Mapping [str , tf .Tensor ],
974
- raw_scores : Mapping [str , tf .Tensor ],
975
- anchor_boxes : Mapping [str , tf .Tensor ],
976
- image_shape : tf .Tensor ,
977
- raw_attributes : Optional [Mapping [str , tf .Tensor ]] = None ):
980
+ def _decode_multilevel_outputs_and_pre_nms_top_k (
981
+ self , raw_boxes : Mapping [str , tf .Tensor ],
982
+ raw_scores : Mapping [str , tf .Tensor ], anchor_boxes : Mapping [str ,
983
+ tf .Tensor ],
984
+ image_shape : tf .Tensor ) -> Tuple [tf .Tensor , tf .Tensor ]:
985
+ """Collects dict of multilevel boxes, scores into lists."""
986
+ boxes = None
987
+ scores = None
988
+
989
+ pre_nms_top_k = self ._config_dict ['pre_nms_top_k' ]
990
+ # TODO(b/258007436): consider removing when compiler be able to handle
991
+ # it on its own.
992
+ pre_nms_top_k_sharding_block = self ._config_dict .get (
993
+ 'pre_nms_top_k_sharding_block' , 128 )
994
+ levels = list (raw_boxes .keys ())
995
+ min_level = int (min (levels ))
996
+ max_level = int (max (levels ))
997
+ for i in range (max_level , min_level - 1 , - 1 ):
998
+ (_ , unsharded_h , unsharded_w , num_anchors_per_locations_times_4
999
+ ) = raw_boxes [str (i )].get_shape ().as_list ()
1000
+ block = max (1 , pre_nms_top_k_sharding_block // unsharded_w )
1001
+ anchor_boxes_unsharded = tf .reshape (
1002
+ anchor_boxes [str (i )],
1003
+ [1 , unsharded_h , unsharded_w , num_anchors_per_locations_times_4 ])
1004
+ for (raw_scores_i , raw_boxes_i , anchor_boxes_i ) in edgetpu .shard_tensors (
1005
+ 1 , block ,
1006
+ (raw_scores [str (i )], raw_boxes [str (i )], anchor_boxes_unsharded )):
1007
+ batch_size = tf .shape (raw_boxes_i )[0 ]
1008
+ (_ , feature_h_i , feature_w_i , _ ) = raw_boxes_i .get_shape ().as_list ()
1009
+ num_locations = feature_h_i * feature_w_i
1010
+ num_anchors_per_locations = num_anchors_per_locations_times_4 // 4
1011
+ num_classes = raw_scores_i .get_shape ().as_list (
1012
+ )[- 1 ] // num_anchors_per_locations
1013
+
1014
+ # Applies score transformation and remove the implicit background class.
1015
+ scores_i = tf .slice (
1016
+ tf .transpose (
1017
+ tf .reshape (raw_scores_i , [
1018
+ batch_size , num_locations * num_anchors_per_locations ,
1019
+ num_classes
1020
+ ]), [0 , 2 , 1 ]), [0 , 1 , 0 ], [- 1 , - 1 , - 1 ])
1021
+
1022
+ # Box decoding.
1023
+ # The anchor boxes are shared for all data in a batch.
1024
+ # One stage detector only supports class agnostic box regression.
1025
+ boxes_shape = [batch_size , num_locations * num_anchors_per_locations , 4 ]
1026
+ boxes_i = tf .tile (
1027
+ tf .expand_dims (
1028
+ box_ops .decode_boxes (
1029
+ tf .reshape (raw_boxes_i , boxes_shape ),
1030
+ tf .reshape (anchor_boxes_i , boxes_shape )),
1031
+ axis = 1 ), [1 , num_classes - 1 , 1 , 1 ])
1032
+ scores , boxes = edgetpu .concat_and_top_k (pre_nms_top_k ,
1033
+ (scores , scores_i ),
1034
+ (boxes , boxes_i ))
1035
+
1036
+ return (box_ops .clip_boxes (boxes ,
1037
+ tf .expand_dims (image_shape ,
1038
+ axis = 1 )), tf .sigmoid (scores ))
1039
+
1040
+ def __call__ (
1041
+ self ,
1042
+ raw_boxes : Mapping [str , tf .Tensor ],
1043
+ raw_scores : Mapping [str , tf .Tensor ],
1044
+ anchor_boxes : Mapping [str , tf .Tensor ],
1045
+ image_shape : tf .Tensor ,
1046
+ raw_attributes : Optional [Mapping [str , tf .Tensor ]] = None
1047
+ ) -> Mapping [str , Any ]:
978
1048
"""Generates final detections.
979
1049
980
1050
Args:
@@ -1031,8 +1101,13 @@ def __call__(self,
1031
1101
'detection_scores' : scores
1032
1102
}
1033
1103
1034
- boxes , scores , attributes = self ._decode_multilevel_outputs (
1035
- raw_boxes , raw_scores , anchor_boxes , image_shape , raw_attributes )
1104
+ if self ._config_dict ['nms_version' ] != 'v3' :
1105
+ boxes , scores , attributes = self ._decode_multilevel_outputs (
1106
+ raw_boxes , raw_scores , anchor_boxes , image_shape , raw_attributes )
1107
+ else :
1108
+ attributes = None
1109
+ boxes , scores = self ._decode_multilevel_outputs_and_pre_nms_top_k (
1110
+ raw_boxes , raw_scores , anchor_boxes , image_shape )
1036
1111
1037
1112
if not self ._config_dict ['apply_nms' ]:
1038
1113
return {
@@ -1086,17 +1161,15 @@ def __call__(self,
1086
1161
# Set `nmsed_attributes` to None for v2.
1087
1162
nmsed_attributes = {}
1088
1163
elif self ._config_dict ['nms_version' ] == 'v3' :
1089
- # TODO(tohaspiridonov): add compatible version of
1090
- # `_decode_multilevel_outputs` in cl/485381750
1091
1164
(nmsed_boxes , nmsed_scores , nmsed_classes , valid_detections ) = (
1092
1165
_generate_detections_v3 (
1093
- tf . transpose ( boxes , [ 0 , 2 , 1 , 3 ]) ,
1094
- tf . transpose ( scores , [ 0 , 2 , 1 ]) ,
1166
+ boxes ,
1167
+ scores ,
1095
1168
pre_nms_score_threshold = self
1096
1169
._config_dict ['pre_nms_score_threshold' ],
1097
1170
nms_iou_threshold = self ._config_dict ['nms_iou_threshold' ],
1098
1171
max_num_detections = self ._config_dict ['max_num_detections' ]))
1099
- # Set `nmsed_attributes` to None for v2 .
1172
+ # Set `nmsed_attributes` to None for v3 .
1100
1173
nmsed_attributes = {}
1101
1174
else :
1102
1175
raise ValueError ('NMS version {} not supported.' .format (
0 commit comments