@@ -45,7 +45,8 @@ def __init__(self,
45
45
annotation_file ,
46
46
include_mask ,
47
47
need_rescale_bboxes = True ,
48
- per_category_metrics = False ):
48
+ per_category_metrics = False ,
49
+ max_num_eval_detections = 100 ):
49
50
"""Constructs COCO evaluation class.
50
51
51
52
The class provides the interface to COCO metrics_fn. The
@@ -62,6 +63,10 @@ def __init__(self,
62
63
need_rescale_bboxes: If true bboxes in `predictions` will be rescaled back
63
64
to absolute values (`image_info` is needed in this case).
64
65
per_category_metrics: Whether to return per category metrics.
66
+ max_num_eval_detections: Maximum number of detections to evaluate in coco
67
+ eval api. Default at 100.
68
+ Raises:
69
+ ValueError: if max_num_eval_detections is not an integer.
65
70
"""
66
71
if annotation_file :
67
72
if annotation_file .startswith ('gs://' ):
@@ -78,10 +83,14 @@ def __init__(self,
78
83
self ._annotation_file = annotation_file
79
84
self ._include_mask = include_mask
80
85
self ._per_category_metrics = per_category_metrics
86
+ if max_num_eval_detections is None or not isinstance (
87
+ max_num_eval_detections , int ):
88
+ raise ValueError ('max_num_eval_detections must be an integer.' )
81
89
self ._metric_names = [
82
90
'AP' , 'AP50' , 'AP75' , 'APs' , 'APm' , 'APl' , 'ARmax1' , 'ARmax10' ,
83
- 'ARmax100 ' , 'ARs' , 'ARm' , 'ARl'
91
+ f'ARmax { max_num_eval_detections } ' , 'ARs' , 'ARm' , 'ARl'
84
92
]
93
+ self .max_num_eval_detections = max_num_eval_detections
85
94
self ._required_prediction_fields = [
86
95
'source_id' , 'num_detections' , 'detection_classes' , 'detection_scores' ,
87
96
'detection_boxes'
@@ -141,6 +150,7 @@ def evaluate(self):
141
150
142
151
coco_eval = cocoeval .COCOeval (coco_gt , coco_dt , iouType = 'bbox' )
143
152
coco_eval .params .imgIds = image_ids
153
+ coco_eval .params .maxDets [2 ] = self .max_num_eval_detections
144
154
coco_eval .evaluate ()
145
155
coco_eval .accumulate ()
146
156
coco_eval .summarize ()
0 commit comments