@@ -382,7 +382,8 @@ def _generate_detections_v3(
382
382
scores : tf .Tensor ,
383
383
pre_nms_score_threshold : float = 0.05 ,
384
384
nms_iou_threshold : float = 0.5 ,
385
- max_num_detections : int = 100
385
+ max_num_detections : int = 100 ,
386
+ refinements : int = 2 ,
386
387
) -> Tuple [tf .Tensor , tf .Tensor , tf .Tensor , tf .Tensor ]:
387
388
"""Generates the detections given the model outputs using NMS for EdgeTPU.
388
389
@@ -400,6 +401,7 @@ def _generate_detections_v3(
400
401
boxes overlap too much with respect to IOU.
401
402
max_num_detections: A `scalar` representing maximum number of boxes retained
402
403
over all classes.
404
+ refinements: Quality parameter for NMS algorithm.
403
405
404
406
Returns:
405
407
nms_boxes: A `float` tf.Tensor of shape [batch_size, max_num_detections, 4]
@@ -434,10 +436,8 @@ def _generate_detections_v3(
434
436
435
437
# EdgeTPU-friendly class-wise NMS, -1 for invalid.
436
438
indices = edgetpu .non_max_suppression_padded (
437
- boxes ,
438
- scores ,
439
- max_num_detections ,
440
- iou_threshold = nms_iou_threshold )
439
+ boxes , scores , max_num_detections , iou_threshold = nms_iou_threshold ,
440
+ refinements = refinements )
441
441
# Gather NMS-ed boxes and scores.
442
442
safe_indices = tf .nn .relu (indices ) # 0 for invalid
443
443
invalid_detections = safe_indices - indices # 1 for invalid, 0 for valid
@@ -859,6 +859,7 @@ def __init__(self,
859
859
soft_nms_sigma : Optional [float ] = None ,
860
860
tflite_post_processing_config : Optional [Dict [str , Any ]] = None ,
861
861
pre_nms_top_k_sharding_block : Optional [int ] = None ,
862
+ nms_v3_refinements : Optional [int ] = None ,
862
863
** kwargs ):
863
864
"""Initializes a multi-level detection generator.
864
865
@@ -882,6 +883,12 @@ def __init__(self,
882
883
pre_nms_top_k_sharding_block: For v3 (edge tpu friendly) NMS, avoids
883
884
creating long axis for pre_nms_top_k. Will do top_k in shards of size
884
885
[num_classes, pre_nms_top_k_sharding_block * boxes_per_location]
886
+ nms_v3_refinements: For v3 (edge tpu friendly) NMS, sets how close result
887
+ should be to standard NMS. When None, 2 is used. Here is some
888
+ experimental deviations for different refinement values:
889
+ if == 0, AP is reduced 1.0%, AR is reduced 5% on COCO
890
+ if == 1, AP is reduced 0.2%, AR is reduced 2% on COCO
891
+ if == 2, AP is reduced <0.1%, AR is reduced <1% on COCO
885
892
886
893
**kwargs: Additional keyword arguments passed to Layer.
887
894
"""
@@ -899,6 +906,9 @@ def __init__(self,
899
906
if pre_nms_top_k_sharding_block is not None :
900
907
self ._config_dict [
901
908
'pre_nms_top_k_sharding_block' ] = pre_nms_top_k_sharding_block
909
+ if nms_v3_refinements is not None :
910
+ self ._config_dict [
911
+ 'nms_v3_refinements' ] = nms_v3_refinements
902
912
903
913
if tflite_post_processing_config is not None :
904
914
self ._config_dict .update (
@@ -999,22 +1009,26 @@ def _decode_multilevel_outputs_and_pre_nms_top_k(
999
1009
levels = list (raw_boxes .keys ())
1000
1010
min_level = int (min (levels ))
1001
1011
max_level = int (max (levels ))
1012
+ clip_shape = tf .expand_dims (tf .expand_dims (image_shape , axis = 1 ), axis = 1 )
1002
1013
for i in range (max_level , min_level - 1 , - 1 ):
1003
1014
(batch_size , unsharded_h , unsharded_w , num_anchors_per_locations_times_4
1004
1015
) = raw_boxes [str (i )].get_shape ().as_list ()
1016
+ num_anchors_per_locations = num_anchors_per_locations_times_4 // 4
1005
1017
if batch_size is None :
1006
1018
batch_size = tf .shape (raw_boxes [str (i )])[0 ]
1007
1019
block = max (1 , pre_nms_top_k_sharding_block // unsharded_w )
1008
- anchor_boxes_unsharded = tf .reshape (anchor_boxes [str (i )], [
1009
- batch_size , unsharded_h , unsharded_w ,
1010
- num_anchors_per_locations_times_4
1011
- ])
1012
- for (raw_scores_i , raw_boxes_i , anchor_boxes_i ) in edgetpu .shard_tensors (
1020
+ boxes_shape = [
1021
+ batch_size , unsharded_h , unsharded_w * num_anchors_per_locations , 4
1022
+ ]
1023
+ decoded_boxes = box_ops .clip_boxes (
1024
+ box_ops .decode_boxes (
1025
+ tf .reshape (raw_boxes [str (i )], boxes_shape ),
1026
+ tf .reshape (anchor_boxes [str (i )], boxes_shape )), clip_shape )
1027
+ for (raw_scores_i , decoded_boxes_i ) in edgetpu .shard_tensors (
1013
1028
1 , block ,
1014
- (raw_scores [str (i )], raw_boxes [ str ( i )], anchor_boxes_unsharded )):
1015
- (_ , feature_h_i , feature_w_i , _ ) = raw_boxes_i .get_shape ().as_list ()
1029
+ (raw_scores [str (i )], decoded_boxes )):
1030
+ (_ , feature_h_i , feature_w_i , _ ) = raw_scores_i .get_shape ().as_list ()
1016
1031
num_locations = feature_h_i * feature_w_i
1017
- num_anchors_per_locations = num_anchors_per_locations_times_4 // 4
1018
1032
num_classes = raw_scores_i .get_shape ().as_list (
1019
1033
)[- 1 ] // num_anchors_per_locations
1020
1034
@@ -1029,18 +1043,16 @@ def _decode_multilevel_outputs_and_pre_nms_top_k(
1029
1043
# Box decoding.
1030
1044
# The anchor boxes are shared for all data in a batch.
1031
1045
# One stage detector only supports class agnostic box regression.
1032
- boxes_shape = [batch_size , num_locations * num_anchors_per_locations , 4 ]
1033
1046
boxes_i = tf .tile (
1034
- tf .expand_dims (
1035
- box_ops .decode_boxes (
1036
- tf .reshape (raw_boxes_i , boxes_shape ),
1037
- tf .reshape (anchor_boxes_i , boxes_shape )),
1038
- axis = 1 ), [1 , num_classes - 1 , 1 , 1 ])
1047
+ tf .reshape (
1048
+ decoded_boxes_i ,
1049
+ [batch_size , 1 , num_locations * num_anchors_per_locations , 4 ]),
1050
+ [1 , num_classes - 1 , 1 , 1 ])
1039
1051
scores , boxes = edgetpu .concat_and_top_k (pre_nms_top_k ,
1040
1052
(scores , scores_i ),
1041
1053
(boxes , boxes_i ))
1042
- clip_shape = tf .expand_dims ( tf . expand_dims ( image_shape , axis = 1 ), axis = 1 )
1043
- return box_ops . clip_boxes ( boxes , clip_shape ) , tf .sigmoid (scores )
1054
+ boxes : tf .Tensor = boxes # pytype: disable=annotation-type-mismatch
1055
+ return boxes , tf .sigmoid (scores )
1044
1056
1045
1057
def __call__ (
1046
1058
self ,
@@ -1173,7 +1185,8 @@ def __call__(
1173
1185
pre_nms_score_threshold = self
1174
1186
._config_dict ['pre_nms_score_threshold' ],
1175
1187
nms_iou_threshold = self ._config_dict ['nms_iou_threshold' ],
1176
- max_num_detections = self ._config_dict ['max_num_detections' ]))
1188
+ max_num_detections = self ._config_dict ['max_num_detections' ],
1189
+ refinements = self ._config_dict .get ('nms_v3_refinements' , 2 )))
1177
1190
# Set `nmsed_attributes` to None for v3.
1178
1191
nmsed_attributes = {}
1179
1192
else :
0 commit comments