Skip to content

Commit 0d85d22

Browse files
tensorflower-gardenerfyangf
authored andcommitted
Internal change
PiperOrigin-RevId: 487291539
1 parent 056cd61 commit 0d85d22

File tree

3 files changed

+17
-3
lines changed

3 files changed

+17
-3
lines changed

official/vision/configs/retinanet.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,9 @@ class RetinaNetTask(cfg.TaskConfig):
172172
# TODO(crisnv) Add paper link when available.
173173
freeze_backbone: bool = False
174174

175+
# Sets maximum number of boxes to be evaluated by coco eval api.
176+
max_num_eval_detections: int = 100
177+
175178

176179
@exp_factory.register_config_factory('retinanet')
177180
def retinanet() -> cfg.ExperimentConfig:

official/vision/evaluation/coco_evaluator.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,8 @@ def __init__(self,
4545
annotation_file,
4646
include_mask,
4747
need_rescale_bboxes=True,
48-
per_category_metrics=False):
48+
per_category_metrics=False,
49+
max_num_eval_detections=100):
4950
"""Constructs COCO evaluation class.
5051
5152
The class provides the interface to COCO metrics_fn. The
@@ -62,6 +63,10 @@ def __init__(self,
6263
need_rescale_bboxes: If true bboxes in `predictions` will be rescaled back
6364
to absolute values (`image_info` is needed in this case).
6465
per_category_metrics: Whether to return per category metrics.
66+
max_num_eval_detections: Maximum number of detections to evaluate in coco
67+
eval api. Default at 100.
68+
Raises:
69+
ValueError: if max_num_eval_detections is not an integer.
6570
"""
6671
if annotation_file:
6772
if annotation_file.startswith('gs://'):
@@ -78,10 +83,14 @@ def __init__(self,
7883
self._annotation_file = annotation_file
7984
self._include_mask = include_mask
8085
self._per_category_metrics = per_category_metrics
86+
if max_num_eval_detections is None or not isinstance(
87+
max_num_eval_detections, int):
88+
raise ValueError('max_num_eval_detections must be an integer.')
8189
self._metric_names = [
8290
'AP', 'AP50', 'AP75', 'APs', 'APm', 'APl', 'ARmax1', 'ARmax10',
83-
'ARmax100', 'ARs', 'ARm', 'ARl'
91+
f'ARmax{max_num_eval_detections}', 'ARs', 'ARm', 'ARl'
8492
]
93+
self.max_num_eval_detections = max_num_eval_detections
8594
self._required_prediction_fields = [
8695
'source_id', 'num_detections', 'detection_classes', 'detection_scores',
8796
'detection_boxes'
@@ -141,6 +150,7 @@ def evaluate(self):
141150

142151
coco_eval = cocoeval.COCOeval(coco_gt, coco_dt, iouType='bbox')
143152
coco_eval.params.imgIds = image_ids
153+
coco_eval.params.maxDets[2] = self.max_num_eval_detections
144154
coco_eval.evaluate()
145155
coco_eval.accumulate()
146156
coco_eval.summarize()

official/vision/tasks/retinanet.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,8 @@ def build_metrics(self, training: bool = True):
246246
self.coco_metric = coco_evaluator.COCOEvaluator(
247247
annotation_file=self.task_config.annotation_file,
248248
include_mask=False,
249-
per_category_metrics=self.task_config.per_category_metrics)
249+
per_category_metrics=self.task_config.per_category_metrics,
250+
max_num_eval_detections=self.task_config.max_num_eval_detections)
250251
if self._task_config.use_wod_metrics:
251252
# To use Waymo open dataset metrics, please install one of the pip
252253
# package `waymo-open-dataset-tf-*` from

0 commit comments

Comments
 (0)