@@ -244,7 +244,8 @@ def build_metrics(self, training: bool = True) -> List[
244
244
dtype = tf .float32 )
245
245
246
246
else :
247
- self ._build_coco_metrics ()
247
+ if self .task_config .use_coco_metrics :
248
+ self ._build_coco_metrics ()
248
249
249
250
rescale_predictions = (not self .task_config .validation_data .parser
250
251
.segmentation_resize_eval_groundtruth )
@@ -366,24 +367,25 @@ def validation_step(self,
366
367
training = False )
367
368
368
369
logs = {self .loss : 0 }
369
- coco_model_outputs = {
370
- 'detection_masks' : outputs ['detection_masks' ],
371
- 'detection_boxes' : outputs ['detection_boxes' ],
372
- 'detection_scores' : outputs ['detection_scores' ],
373
- 'detection_classes' : outputs ['detection_classes' ],
374
- 'num_detections' : outputs ['num_detections' ],
375
- 'source_id' : labels ['groundtruths' ]['source_id' ],
376
- 'image_info' : labels ['image_info' ]
377
- }
370
+ if self ._task_config .use_coco_metrics :
371
+ coco_model_outputs = {
372
+ 'detection_masks' : outputs ['detection_masks' ],
373
+ 'detection_boxes' : outputs ['detection_boxes' ],
374
+ 'detection_scores' : outputs ['detection_scores' ],
375
+ 'detection_classes' : outputs ['detection_classes' ],
376
+ 'num_detections' : outputs ['num_detections' ],
377
+ 'source_id' : labels ['groundtruths' ]['source_id' ],
378
+ 'image_info' : labels ['image_info' ]
379
+ }
380
+ logs .update (
381
+ {self .coco_metric .name : (labels ['groundtruths' ], coco_model_outputs )})
382
+
378
383
segmentation_labels = {
379
384
'masks' : labels ['groundtruths' ]['gt_segmentation_mask' ],
380
385
'valid_masks' : labels ['groundtruths' ]['gt_segmentation_valid_mask' ],
381
386
'image_info' : labels ['image_info' ]
382
387
}
383
388
384
- logs .update (
385
- {self .coco_metric .name : (labels ['groundtruths' ], coco_model_outputs )})
386
-
387
389
self .segmentation_perclass_iou_metric .update_state (
388
390
segmentation_labels , outputs ['segmentation_outputs' ])
389
391
@@ -400,15 +402,18 @@ def validation_step(self,
400
402
401
403
def aggregate_logs (self , state = None , step_outputs = None ):
402
404
if state is None :
403
- self .coco_metric .reset_states ()
404
405
self .segmentation_perclass_iou_metric .reset_states ()
405
- state = [self .coco_metric , self .segmentation_perclass_iou_metric ]
406
+ state = [self .segmentation_perclass_iou_metric ]
407
+ if self .task_config .use_coco_metrics :
408
+ self .coco_metric .reset_states ()
409
+ state .append (self .coco_metric )
406
410
if self .task_config .model .generate_panoptic_masks :
407
- state += [self .panoptic_quality_metric ]
411
+ self .panoptic_quality_metric .reset_states ()
412
+ state .append (self .panoptic_quality_metric )
408
413
409
- self .coco_metric . update_state (
410
- step_outputs [self .coco_metric .name ][0 ],
411
- step_outputs [self .coco_metric .name ][1 ])
414
+ if self .task_config . use_coco_metrics :
415
+ self . coco_metric . update_state ( step_outputs [self .coco_metric .name ][0 ],
416
+ step_outputs [self .coco_metric .name ][1 ])
412
417
413
418
if self .task_config .model .generate_panoptic_masks :
414
419
self .panoptic_quality_metric .update_state (
0 commit comments