@@ -19,6 +19,7 @@ limitations under the License.
1919#include < limits>
2020#include < vector>
2121
22+ #include < glog/logging.h>
2223#include " external/com_google_absl/absl/memory/memory.h"
2324#include " external/com_google_absl/absl/status/status.h"
2425#include " external/com_google_absl/absl/strings/str_format.h"
@@ -56,11 +57,15 @@ using ::tflite::support::CreateStatusWithPayload;
5657using ::tflite::support::StatusOr;
5758using ::tflite::support::TfLiteSupportStatus;
5859using ::tflite::task::core::AssertAndReturnTypedTensor;
60+ using ::tflite::task::core::FindIndexByMetadataTensorName;
5961using ::tflite::task::core::TaskAPIFactory;
6062using ::tflite::task::core::TfLiteEngine;
6163
6264// The expected number of dimensions of the 4 output tensors, representing in
63- // that order: locations, classes, scores, num_results.
65+ // that order: locations, categories, scores, num_results. The order is
66+ // coming from the TFLite custom NMS op for object detection post-processing
67+ // shown in
68+ // https://github.com/tensorflow/tensorflow/blob/1c419b231b622bd9e9685682545e9064f0fbb42a/tensorflow/lite/kernels/detection_postprocess.cc#L47.
6469static constexpr int kOutputTensorsExpectedDims [4 ] = {3 , 2 , 2 , 1 };
6570constexpr int kDefaultLocationsIndex = 0 ;
6671constexpr int kDefaultClassesIndex = 1 ;
@@ -69,6 +74,11 @@ constexpr int kDefaultNumResultsIndex = 3;
6974
7075constexpr float kDefaultScoreThreshold = std::numeric_limits<float >::lowest();
7176
77+ constexpr char kLocationTensorName [] = " location" ;
78+ constexpr char kCategoryTensorName [] = " category" ;
79+ constexpr char kScoreTensorName [] = " score" ;
80+ constexpr char kNumberOfDetectionsTensorName [] = " number of detections" ;
81+
7282StatusOr<const BoundingBoxProperties*> GetBoundingBoxProperties (
7383 const TensorMetadata& tensor_metadata) {
7484 if (tensor_metadata.content () == nullptr ||
@@ -166,8 +176,38 @@ StatusOr<float> GetScoreThreshold(
166176 ->global_score_threshold ();
167177}
168178
179+ // Use tensor names in metadata to get the output order.
180+ std::vector<int > GetOutputIndices (
181+ const flatbuffers::Vector<flatbuffers::Offset<TensorMetadata>>*
182+ tensor_metadatas) {
183+ std::vector<int > output_indices = {
184+ FindIndexByMetadataTensorName (tensor_metadatas, kLocationTensorName ),
185+ FindIndexByMetadataTensorName (tensor_metadatas, kCategoryTensorName ),
186+ FindIndexByMetadataTensorName (tensor_metadatas, kScoreTensorName ),
187+ FindIndexByMetadataTensorName (tensor_metadatas,
188+ kNumberOfDetectionsTensorName )};
189+
190+ for (int i = 0 ; i < 4 ; i++) {
191+ int output_index = output_indices[i];
192+ // If tensor name is not found, set the default output indices.
193+ if (output_index == -1 ) {
194+ LOG (WARNING) << absl::StrFormat (
195+ " You don't seem to be matching tensor names in metadata list. The "
196+ " tensor name \" %s\" at index %d in the model metadata doesn't match "
197+ " the available output names: [\" %s\" , \" %s\" , \" %s\" , \" %s\" ]." ,
198+ tensor_metadatas->Get (i)->name ()->c_str (), i, kLocationTensorName ,
199+ kCategoryTensorName , kScoreTensorName , kNumberOfDetectionsTensorName );
200+ output_indices = {kDefaultLocationsIndex , kDefaultClassesIndex ,
201+ kDefaultScoresIndex , kDefaultNumResultsIndex };
202+ return output_indices;
203+ }
204+ }
205+ return output_indices;
206+ }
207+
169208absl::Status SanityCheckOutputTensors (
170- const std::vector<const TfLiteTensor*>& output_tensors) {
209+ const std::vector<const TfLiteTensor*>& output_tensors,
210+ const std::vector<int >& output_indices) {
171211 if (output_tensors.size () != 4 ) {
172212 return CreateStatusWithPayload (
173213 StatusCode::kInternal ,
@@ -176,51 +216,53 @@ absl::Status SanityCheckOutputTensors(
176216 }
177217
178218 // Get number of results.
179- if (output_tensors[kDefaultNumResultsIndex ]->dims ->data [0 ] != 1 ) {
219+ const TfLiteTensor* num_results_tensor = output_tensors[output_indices[3 ]];
220+ if (num_results_tensor->dims ->data [0 ] != 1 ) {
180221 return CreateStatusWithPayload (
181222 StatusCode::kInternal ,
182223 absl::StrFormat (
183224 " Expected tensor with dimensions [1] at index 3, found [%d]" ,
184- output_tensors[ kDefaultNumResultsIndex ] ->dims ->data [0 ]));
225+ num_results_tensor ->dims ->data [0 ]));
185226 }
186- int num_results = static_cast <int >(AssertAndReturnTypedTensor< float >(
187- output_tensors[ kDefaultNumResultsIndex ] )[0 ]);
227+ int num_results = static_cast <int >(
228+ AssertAndReturnTypedTensor< float >(num_results_tensor )[0 ]);
188229
230+ const TfLiteTensor* location_tensor = output_tensors[output_indices[0 ]];
189231 // Check dimensions for the other tensors are correct.
190- if (output_tensors[ kDefaultLocationsIndex ] ->dims ->data [0 ] != 1 ||
191- output_tensors[ kDefaultLocationsIndex ] ->dims ->data [1 ] < num_results ||
192- output_tensors[ kDefaultLocationsIndex ] ->dims ->data [2 ] != 4 ) {
232+ if (location_tensor ->dims ->data [0 ] != 1 ||
233+ location_tensor ->dims ->data [1 ] < num_results ||
234+ location_tensor ->dims ->data [2 ] != 4 ) {
193235 return CreateStatusWithPayload (
194236 StatusCode::kInternal ,
195237 absl::StrFormat (
196238 " Expected locations tensor with dimensions [1, num_detected_boxes, "
197- " 4] at index %d , num_detected_boxes >= %d, found [%d,%d,%d]." ,
198- kDefaultLocationsIndex , num_results ,
199- output_tensors[ kDefaultLocationsIndex ] ->dims ->data [0 ],
200- output_tensors[ kDefaultLocationsIndex ]-> dims -> data [ 1 ],
201- output_tensors[ kDefaultLocationsIndex ]-> dims -> data [ 2 ]));
202- }
203- if (output_tensors[ kDefaultClassesIndex ] ->dims ->data [0 ] != 1 ||
204- output_tensors[ kDefaultClassesIndex ] ->dims ->data [1 ] < num_results) {
239+ " 4] at index 0 , num_detected_boxes >= %d, found [%d,%d,%d]." ,
240+ num_results, location_tensor-> dims -> data [ 0 ] ,
241+ location_tensor-> dims -> data [ 1 ], location_tensor ->dims ->data [2 ]));
242+ }
243+
244+ const TfLiteTensor* class_tensor = output_tensors[output_indices[ 1 ]];
245+ if (class_tensor ->dims ->data [0 ] != 1 ||
246+ class_tensor ->dims ->data [1 ] < num_results) {
205247 return CreateStatusWithPayload (
206248 StatusCode::kInternal ,
207249 absl::StrFormat (
208250 " Expected classes tensor with dimensions [1, num_detected_boxes] "
209- " at index %d, num_detected_boxes >= %d, found [%d,%d]." ,
210- kDefaultClassesIndex , num_results,
211- output_tensors[kDefaultClassesIndex ]->dims ->data [0 ],
212- output_tensors[kDefaultClassesIndex ]->dims ->data [1 ]));
251+ " at index 1, num_detected_boxes >= %d, found [%d,%d]." ,
252+ num_results, class_tensor->dims ->data [0 ],
253+ class_tensor->dims ->data [1 ]));
213254 }
214- if (output_tensors[kDefaultScoresIndex ]->dims ->data [0 ] != 1 ||
215- output_tensors[kDefaultScoresIndex ]->dims ->data [1 ] < num_results) {
255+
256+ const TfLiteTensor* scores_tensor = output_tensors[output_indices[2 ]];
257+ if (scores_tensor->dims ->data [0 ] != 1 ||
258+ scores_tensor->dims ->data [1 ] < num_results) {
216259 return CreateStatusWithPayload (
217260 StatusCode::kInternal ,
218261 absl::StrFormat (
219262 " Expected scores tensor with dimensions [1, num_detected_boxes] "
220- " at index %d, num_detected_boxes >= %d, found [%d,%d]." ,
221- kDefaultScoresIndex , num_results,
222- output_tensors[kDefaultScoresIndex ]->dims ->data [0 ],
223- output_tensors[kDefaultScoresIndex ]->dims ->data [1 ]));
263+ " at index 2, num_detected_boxes >= %d, found [%d,%d]." ,
264+ num_results, scores_tensor->dims ->data [0 ],
265+ scores_tensor->dims ->data [1 ]));
224266 }
225267
226268 return absl::OkStatus ();
@@ -409,25 +451,6 @@ absl::Status ObjectDetector::CheckAndSetOutputs() {
409451 TfLiteEngine::OutputCount (interpreter)),
410452 TfLiteSupportStatus::kInvalidNumOutputTensorsError );
411453 }
412- // Check tensor dimensions and batch size.
413- for (int i = 0 ; i < 4 ; ++i) {
414- const TfLiteTensor* tensor = TfLiteEngine::GetOutput (interpreter, i);
415- if (tensor->dims ->size != kOutputTensorsExpectedDims [i]) {
416- return CreateStatusWithPayload (
417- StatusCode::kInvalidArgument ,
418- absl::StrFormat (" Output tensor at index %d is expected to "
419- " have %d dimensions, found %d." ,
420- i, kOutputTensorsExpectedDims [i], tensor->dims ->size ),
421- TfLiteSupportStatus::kInvalidOutputTensorDimensionsError );
422- }
423- if (tensor->dims ->data [0 ] != 1 ) {
424- return CreateStatusWithPayload (
425- StatusCode::kInvalidArgument ,
426- absl::StrFormat (" Expected batch size of 1, found %d." ,
427- tensor->dims ->data [0 ]),
428- TfLiteSupportStatus::kInvalidOutputTensorDimensionsError );
429- }
430- }
431454
432455 // Now, perform sanity checks and extract metadata.
433456 const ModelMetadataExtractor* metadata_extractor =
@@ -455,10 +478,13 @@ absl::Status ObjectDetector::CheckAndSetOutputs() {
455478 TfLiteSupportStatus::kMetadataInconsistencyError );
456479 }
457480
481+ output_indices_ = GetOutputIndices (output_tensors_metadata);
482+
458483 // Extract mandatory BoundingBoxProperties for easier access at
459484 // post-processing time, performing sanity checks on the fly.
460485 ASSIGN_OR_RETURN (const BoundingBoxProperties* bounding_box_properties,
461- GetBoundingBoxProperties (*output_tensors_metadata->Get (0 )));
486+ GetBoundingBoxProperties (
487+ *output_tensors_metadata->Get (output_indices_[0 ])));
462488 if (bounding_box_properties->index () == nullptr ) {
463489 bounding_box_corners_order_ = {0 , 1 , 2 , 3 };
464490 } else {
@@ -474,16 +500,39 @@ absl::Status ObjectDetector::CheckAndSetOutputs() {
474500 // Build label map (if available) from metadata.
475501 ASSIGN_OR_RETURN (
476502 label_map_,
477- GetLabelMapIfAny (*metadata_extractor, *output_tensors_metadata->Get (1 ),
503+ GetLabelMapIfAny (*metadata_extractor,
504+ *output_tensors_metadata->Get (output_indices_[1 ]),
478505 options_->display_names_locale ()));
479506
480507 // Set score threshold.
481508 if (options_->has_score_threshold ()) {
482509 score_threshold_ = options_->score_threshold ();
483510 } else {
484- ASSIGN_OR_RETURN (score_threshold_,
485- GetScoreThreshold (*metadata_extractor,
486- *output_tensors_metadata->Get (2 )));
511+ ASSIGN_OR_RETURN (
512+ score_threshold_,
513+ GetScoreThreshold (*metadata_extractor,
514+ *output_tensors_metadata->Get (output_indices_[2 ])));
515+ }
516+
517+ // Check tensor dimensions and batch size.
518+ for (int i = 0 ; i < 4 ; ++i) {
519+ std::size_t j = output_indices_[i];
520+ const TfLiteTensor* tensor = TfLiteEngine::GetOutput (interpreter, j);
521+ if (tensor->dims ->size != kOutputTensorsExpectedDims [i]) {
522+ return CreateStatusWithPayload (
523+ StatusCode::kInvalidArgument ,
524+ absl::StrFormat (" Output tensor at index %d is expected to "
525+ " have %d dimensions, found %d." ,
526+ j, kOutputTensorsExpectedDims [i], tensor->dims ->size ),
527+ TfLiteSupportStatus::kInvalidOutputTensorDimensionsError );
528+ }
529+ if (tensor->dims ->data [0 ] != 1 ) {
530+ return CreateStatusWithPayload (
531+ StatusCode::kInvalidArgument ,
532+ absl::StrFormat (" Expected batch size of 1, found %d." ,
533+ tensor->dims ->data [0 ]),
534+ TfLiteSupportStatus::kInvalidOutputTensorDimensionsError );
535+ }
487536 }
488537
489538 return absl::OkStatus ();
@@ -551,11 +600,11 @@ StatusOr<DetectionResult> ObjectDetector::Postprocess(
551600 // Most of the checks here should never happen, as outputs have been validated
552601 // at construction time. Checking nonetheless and returning internal errors if
553602 // something bad happens.
554- RETURN_IF_ERROR (SanityCheckOutputTensors (output_tensors));
603+ RETURN_IF_ERROR (SanityCheckOutputTensors (output_tensors, output_indices_ ));
555604
556605 // Get number of available results.
557- const int num_results = static_cast <int >(AssertAndReturnTypedTensor< float >(
558- output_tensors[kDefaultNumResultsIndex ])[0 ]);
606+ const int num_results = static_cast <int >(
607+ AssertAndReturnTypedTensor< float >( output_tensors[output_indices_[ 3 ] ])[0 ]);
559608 // Compute number of max results to return.
560609 const int max_results = options_->max_results () > 0
561610 ? std::min (options_->max_results (), num_results)
@@ -569,11 +618,11 @@ StatusOr<DetectionResult> ObjectDetector::Postprocess(
569618 upright_input_frame_dimensions.Swap ();
570619 }
571620 const float * locations =
572- AssertAndReturnTypedTensor<float >(output_tensors[kDefaultLocationsIndex ]);
621+ AssertAndReturnTypedTensor<float >(output_tensors[output_indices_[ 0 ] ]);
573622 const float * classes =
574- AssertAndReturnTypedTensor<float >(output_tensors[kDefaultClassesIndex ]);
623+ AssertAndReturnTypedTensor<float >(output_tensors[output_indices_[ 1 ] ]);
575624 const float * scores =
576- AssertAndReturnTypedTensor<float >(output_tensors[kDefaultScoresIndex ]);
625+ AssertAndReturnTypedTensor<float >(output_tensors[output_indices_[ 2 ] ]);
577626 DetectionResult results;
578627 for (int i = 0 ; i < num_results; ++i) {
579628 const int class_index = static_cast <int >(classes[i]);
0 commit comments