Skip to content

Commit 00f71bf

Browse files
tensorflower-gardenerTF Object Detection Team
authored andcommitted
Image-level labels are not propagated to the open images challenge metric.
PiperOrigin-RevId: 417394640
1 parent 48b4b57 commit 00f71bf

File tree

4 files changed

+34
-1
lines changed

4 files changed

+34
-1
lines changed

research/object_detection/core/model.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,7 @@ def provide_groundtruth(
327327
groundtruth_not_exhaustive_classes=None,
328328
groundtruth_keypoint_depths_list=None,
329329
groundtruth_keypoint_depth_weights_list=None,
330+
groundtruth_image_classes=None,
330331
training_step=None):
331332
"""Provide groundtruth tensors.
332333
@@ -398,6 +399,9 @@ def provide_groundtruth(
398399
groundtruth_keypoint_depth_weights_list: a list of 2-D tf.float32 tensors
399400
of shape [num_boxes, num_keypoints] containing the weights of the
400401
relative depths.
402+
groundtruth_image_classes: A list of 1-D tf.float32 tensors of shape
403+
[num_classes], containing label indices encoded as k-hot of the classes
404+
that are present or not present in the image.
401405
training_step: An integer denoting the current training step. This is
402406
useful when models want to anneal loss terms.
403407
"""
@@ -474,6 +478,10 @@ def provide_groundtruth(
474478
self._groundtruth_lists[
475479
fields.InputDataFields
476480
.groundtruth_verified_neg_classes] = groundtruth_verified_neg_classes
481+
if groundtruth_image_classes:
482+
self._groundtruth_lists[
483+
fields.InputDataFields
484+
.groundtruth_image_classes] = groundtruth_image_classes
477485
if groundtruth_not_exhaustive_classes:
478486
self._groundtruth_lists[
479487
fields.InputDataFields

research/object_detection/inputs.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -668,7 +668,8 @@ def _get_labels_dict(input_dict):
668668
fields.InputDataFields.groundtruth_dp_surface_coords,
669669
fields.InputDataFields.groundtruth_track_ids,
670670
fields.InputDataFields.groundtruth_verified_neg_classes,
671-
fields.InputDataFields.groundtruth_not_exhaustive_classes
671+
fields.InputDataFields.groundtruth_not_exhaustive_classes,
672+
fields.InputDataFields.groundtruth_image_classes,
672673
]
673674

674675
for key in optional_label_keys:

research/object_detection/model_lib.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,10 @@ def _prepare_groundtruth_for_eval(detection_model, class_agnostic,
114114
'groundtruth_not_exhaustive_classes': [batch_size, num_classes] K-hot
115115
representation of 1-indexed classes which don't have all of their
116116
instances marked exhaustively.
117+
'input_data_fields.groundtruth_image_classes': integer representation of
118+
the classes that were sent for verification for a given image. Note that
119+
this field does not support batching as the number of classes can be
120+
variable.
117121
class_agnostic: Boolean indicating whether detections are class agnostic.
118122
"""
119123
input_data_fields = fields.InputDataFields()
@@ -136,6 +140,18 @@ def _prepare_groundtruth_for_eval(detection_model, class_agnostic,
136140
input_data_fields.groundtruth_classes: groundtruth_classes
137141
}
138142

143+
if detection_model.groundtruth_has_field(
144+
input_data_fields.groundtruth_image_classes):
145+
groundtruth_image_classes_k_hot = tf.stack(
146+
detection_model.groundtruth_lists(
147+
input_data_fields.groundtruth_image_classes))
148+
# We do not add label_id_offset here because it was not added when encoding
149+
# groundtruth_image_classes.
150+
groundtruth_image_classes = tf.expand_dims(
151+
tf.where(groundtruth_image_classes_k_hot > 0)[:, 1], 0)
152+
groundtruth[
153+
input_data_fields.groundtruth_image_classes] = groundtruth_image_classes
154+
139155
if detection_model.groundtruth_has_field(fields.BoxListFields.masks):
140156
groundtruth[input_data_fields.groundtruth_instance_masks] = tf.stack(
141157
detection_model.groundtruth_lists(fields.BoxListFields.masks))
@@ -384,6 +400,10 @@ def provide_groundtruth(model, labels, training_step=None):
384400
if fields.InputDataFields.groundtruth_not_exhaustive_classes in labels:
385401
gt_not_exhaustive_classes = labels[
386402
fields.InputDataFields.groundtruth_not_exhaustive_classes]
403+
groundtruth_image_classes = None
404+
if fields.InputDataFields.groundtruth_image_classes in labels:
405+
groundtruth_image_classes = labels[
406+
fields.InputDataFields.groundtruth_image_classes]
387407
model.provide_groundtruth(
388408
groundtruth_boxes_list=gt_boxes_list,
389409
groundtruth_classes_list=gt_classes_list,
@@ -405,6 +425,7 @@ def provide_groundtruth(model, labels, training_step=None):
405425
groundtruth_not_exhaustive_classes=gt_not_exhaustive_classes,
406426
groundtruth_keypoint_depths_list=gt_keypoint_depths_list,
407427
groundtruth_keypoint_depth_weights_list=gt_keypoint_depth_weights_list,
428+
groundtruth_image_classes=groundtruth_image_classes,
408429
training_step=training_step)
409430

410431

research/object_detection/utils/object_detection_evaluation.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -870,6 +870,9 @@ def add_single_ground_truth_image_info(self, image_id, groundtruth_dict):
870870
image_classes = groundtruth_dict[input_fields.groundtruth_image_classes]
871871
elif input_fields.groundtruth_labeled_classes in groundtruth_dict:
872872
image_classes = groundtruth_dict[input_fields.groundtruth_labeled_classes]
873+
else:
874+
logging.warning('No image classes field found for image with id %s!',
875+
image_id)
873876
image_classes -= self._label_id_offset
874877
self._evaluatable_labels[image_id] = np.unique(
875878
np.concatenate((image_classes, groundtruth_classes)))

0 commit comments

Comments
 (0)