@@ -114,6 +114,10 @@ def _prepare_groundtruth_for_eval(detection_model, class_agnostic,
114
114
'groundtruth_not_exhaustive_classes': [batch_size, num_classes] K-hot
115
115
representation of 1-indexed classes which don't have all of their
116
116
instances marked exhaustively.
117
+ 'input_data_fields.groundtruth_image_classes': integer representation of
118
+ the classes that were sent for verification for a given image. Note that
119
+ this field does not support batching as the number of classes can be
120
+ variable.
117
121
class_agnostic: Boolean indicating whether detections are class agnostic.
118
122
"""
119
123
input_data_fields = fields .InputDataFields ()
@@ -136,6 +140,18 @@ def _prepare_groundtruth_for_eval(detection_model, class_agnostic,
136
140
input_data_fields .groundtruth_classes : groundtruth_classes
137
141
}
138
142
143
+ if detection_model .groundtruth_has_field (
144
+ input_data_fields .groundtruth_image_classes ):
145
+ groundtruth_image_classes_k_hot = tf .stack (
146
+ detection_model .groundtruth_lists (
147
+ input_data_fields .groundtruth_image_classes ))
148
+ # We do not add label_id_offset here because it was not added when encoding
149
+ # groundtruth_image_classes.
150
+ groundtruth_image_classes = tf .expand_dims (
151
+ tf .where (groundtruth_image_classes_k_hot > 0 )[:, 1 ], 0 )
152
+ groundtruth [
153
+ input_data_fields .groundtruth_image_classes ] = groundtruth_image_classes
154
+
139
155
if detection_model .groundtruth_has_field (fields .BoxListFields .masks ):
140
156
groundtruth [input_data_fields .groundtruth_instance_masks ] = tf .stack (
141
157
detection_model .groundtruth_lists (fields .BoxListFields .masks ))
@@ -384,6 +400,10 @@ def provide_groundtruth(model, labels, training_step=None):
384
400
if fields .InputDataFields .groundtruth_not_exhaustive_classes in labels :
385
401
gt_not_exhaustive_classes = labels [
386
402
fields .InputDataFields .groundtruth_not_exhaustive_classes ]
403
+ groundtruth_image_classes = None
404
+ if fields .InputDataFields .groundtruth_image_classes in labels :
405
+ groundtruth_image_classes = labels [
406
+ fields .InputDataFields .groundtruth_image_classes ]
387
407
model .provide_groundtruth (
388
408
groundtruth_boxes_list = gt_boxes_list ,
389
409
groundtruth_classes_list = gt_classes_list ,
@@ -405,6 +425,7 @@ def provide_groundtruth(model, labels, training_step=None):
405
425
groundtruth_not_exhaustive_classes = gt_not_exhaustive_classes ,
406
426
groundtruth_keypoint_depths_list = gt_keypoint_depths_list ,
407
427
groundtruth_keypoint_depth_weights_list = gt_keypoint_depth_weights_list ,
428
+ groundtruth_image_classes = groundtruth_image_classes ,
408
429
training_step = training_step )
409
430
410
431
0 commit comments