Skip to content

Commit 9bcbe96

Browse files
Internal change
PiperOrigin-RevId: 481231095
1 parent f23d7bc commit 9bcbe96

File tree

1 file changed

+24
-19
lines changed

1 file changed

+24
-19
lines changed

official/projects/panoptic/tasks/panoptic_maskrcnn.py

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,8 @@ def build_metrics(self, training: bool = True) -> List[
244244
dtype=tf.float32)
245245

246246
else:
247-
self._build_coco_metrics()
247+
if self.task_config.use_coco_metrics:
248+
self._build_coco_metrics()
248249

249250
rescale_predictions = (not self.task_config.validation_data.parser
250251
.segmentation_resize_eval_groundtruth)
@@ -366,24 +367,25 @@ def validation_step(self,
366367
training=False)
367368

368369
logs = {self.loss: 0}
369-
coco_model_outputs = {
370-
'detection_masks': outputs['detection_masks'],
371-
'detection_boxes': outputs['detection_boxes'],
372-
'detection_scores': outputs['detection_scores'],
373-
'detection_classes': outputs['detection_classes'],
374-
'num_detections': outputs['num_detections'],
375-
'source_id': labels['groundtruths']['source_id'],
376-
'image_info': labels['image_info']
377-
}
370+
if self._task_config.use_coco_metrics:
371+
coco_model_outputs = {
372+
'detection_masks': outputs['detection_masks'],
373+
'detection_boxes': outputs['detection_boxes'],
374+
'detection_scores': outputs['detection_scores'],
375+
'detection_classes': outputs['detection_classes'],
376+
'num_detections': outputs['num_detections'],
377+
'source_id': labels['groundtruths']['source_id'],
378+
'image_info': labels['image_info']
379+
}
380+
logs.update(
381+
{self.coco_metric.name: (labels['groundtruths'], coco_model_outputs)})
382+
378383
segmentation_labels = {
379384
'masks': labels['groundtruths']['gt_segmentation_mask'],
380385
'valid_masks': labels['groundtruths']['gt_segmentation_valid_mask'],
381386
'image_info': labels['image_info']
382387
}
383388

384-
logs.update(
385-
{self.coco_metric.name: (labels['groundtruths'], coco_model_outputs)})
386-
387389
self.segmentation_perclass_iou_metric.update_state(
388390
segmentation_labels, outputs['segmentation_outputs'])
389391

@@ -400,15 +402,18 @@ def validation_step(self,
400402

401403
def aggregate_logs(self, state=None, step_outputs=None):
402404
if state is None:
403-
self.coco_metric.reset_states()
404405
self.segmentation_perclass_iou_metric.reset_states()
405-
state = [self.coco_metric, self.segmentation_perclass_iou_metric]
406+
state = [self.segmentation_perclass_iou_metric]
407+
if self.task_config.use_coco_metrics:
408+
self.coco_metric.reset_states()
409+
state.append(self.coco_metric)
406410
if self.task_config.model.generate_panoptic_masks:
407-
state += [self.panoptic_quality_metric]
411+
self.panoptic_quality_metric.reset_states()
412+
state.append(self.panoptic_quality_metric)
408413

409-
self.coco_metric.update_state(
410-
step_outputs[self.coco_metric.name][0],
411-
step_outputs[self.coco_metric.name][1])
414+
if self.task_config.use_coco_metrics:
415+
self.coco_metric.update_state(step_outputs[self.coco_metric.name][0],
416+
step_outputs[self.coco_metric.name][1])
412417

413418
if self.task_config.model.generate_panoptic_masks:
414419
self.panoptic_quality_metric.update_state(

0 commit comments

Comments
 (0)