Skip to content

Commit 93a0409

Browse files
ziyeqinghantflite-support-robot
authored andcommitted
Object Detection Task Library: Map the output by the tensor name in metadata.
PiperOrigin-RevId: 399841823
1 parent b5cc57c commit 93a0409

File tree

7 files changed

+141
-63
lines changed

7 files changed

+141
-63
lines changed

tensorflow_lite_support/cc/task/core/task_utils.cc

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,22 @@ std::string LoadBinaryContent(const char* filename) {
6161
return buffer;
6262
}
6363

64+
int FindIndexByMetadataTensorName(
65+
const flatbuffers::Vector<flatbuffers::Offset<TensorMetadata>>*
66+
tensor_metadatas,
67+
const std::string& name) {
68+
if (tensor_metadatas == nullptr) {
69+
return -1;
70+
}
71+
for (int i = 0; i < tensor_metadatas->size(); i++) {
72+
if (strcmp(name.data(), tensor_metadatas->Get(i)->name()->c_str()) == 0) {
73+
return i;
74+
}
75+
}
76+
// Returns -1 if not found.
77+
return -1;
78+
}
79+
6480
} // namespace core
6581
} // namespace task
6682
} // namespace tflite

tensorflow_lite_support/cc/task/core/task_utils.h

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,14 @@ std::string GetStringAtIndex(const TfLiteTensor* labels, int index);
156156
// Loads binary content of a file into a string.
157157
std::string LoadBinaryContent(const char* filename);
158158

159+
// Gets the index from a vector of tensors with name specified inside metadata.
160+
// The returned index is in the range [0, tensor_metadatas->size()). If not
161+
// found, returns -1.
162+
int FindIndexByMetadataTensorName(
163+
const flatbuffers::Vector<flatbuffers::Offset<TensorMetadata>>*
164+
tensor_metadatas,
165+
const std::string& name);
166+
159167
// Gets the tensor from a vector of tensors with name specified inside metadata.
160168
template <typename TensorType>
161169
static TensorType* FindTensorByName(
@@ -167,12 +175,8 @@ static TensorType* FindTensorByName(
167175
tensor_metadatas->size() != tensors.size()) {
168176
return nullptr;
169177
}
170-
for (flatbuffers::uoffset_t i = 0; i < tensor_metadatas->size(); i++) {
171-
if (strcmp(name.data(), tensor_metadatas->Get(i)->name()->c_str()) == 0) {
172-
return tensors[i];
173-
}
174-
}
175-
return nullptr;
178+
int i = FindIndexByMetadataTensorName(tensor_metadatas, name);
179+
return i == -1 ? nullptr : tensors[i];
176180
}
177181

178182
} // namespace core

tensorflow_lite_support/cc/task/vision/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ cc_library_with_tflite(
4444
"@com_google_absl//absl/status",
4545
"@com_google_absl//absl/strings",
4646
"@com_google_absl//absl/strings:str_format",
47+
"@com_google_glog//:glog",
4748
"@org_tensorflow//tensorflow/lite/c:common",
4849
"@org_tensorflow//tensorflow/lite/core/api",
4950
],

tensorflow_lite_support/cc/task/vision/object_detector.cc

Lines changed: 106 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ limitations under the License.
1919
#include <limits>
2020
#include <vector>
2121

22+
#include <glog/logging.h>
2223
#include "external/com_google_absl/absl/memory/memory.h"
2324
#include "external/com_google_absl/absl/status/status.h"
2425
#include "external/com_google_absl/absl/strings/str_format.h"
@@ -56,11 +57,15 @@ using ::tflite::support::CreateStatusWithPayload;
5657
using ::tflite::support::StatusOr;
5758
using ::tflite::support::TfLiteSupportStatus;
5859
using ::tflite::task::core::AssertAndReturnTypedTensor;
60+
using ::tflite::task::core::FindIndexByMetadataTensorName;
5961
using ::tflite::task::core::TaskAPIFactory;
6062
using ::tflite::task::core::TfLiteEngine;
6163

6264
// The expected number of dimensions of the 4 output tensors, representing in
63-
// that order: locations, classes, scores, num_results.
65+
// that order: locations, categories, scores, num_results. The order is
66+
// coming from the TFLite custom NMS op for object detection post-processing
67+
// shown in
68+
// https://github.com/tensorflow/tensorflow/blob/1c419b231b622bd9e9685682545e9064f0fbb42a/tensorflow/lite/kernels/detection_postprocess.cc#L47.
6469
static constexpr int kOutputTensorsExpectedDims[4] = {3, 2, 2, 1};
6570
constexpr int kDefaultLocationsIndex = 0;
6671
constexpr int kDefaultClassesIndex = 1;
@@ -69,6 +74,11 @@ constexpr int kDefaultNumResultsIndex = 3;
6974

7075
constexpr float kDefaultScoreThreshold = std::numeric_limits<float>::lowest();
7176

77+
constexpr char kLocationTensorName[] = "location";
78+
constexpr char kCategoryTensorName[] = "category";
79+
constexpr char kScoreTensorName[] = "score";
80+
constexpr char kNumberOfDetectionsTensorName[] = "number of detections";
81+
7282
StatusOr<const BoundingBoxProperties*> GetBoundingBoxProperties(
7383
const TensorMetadata& tensor_metadata) {
7484
if (tensor_metadata.content() == nullptr ||
@@ -166,8 +176,38 @@ StatusOr<float> GetScoreThreshold(
166176
->global_score_threshold();
167177
}
168178

179+
// Use tensor names in metadata to get the output order.
180+
std::vector<int> GetOutputIndices(
181+
const flatbuffers::Vector<flatbuffers::Offset<TensorMetadata>>*
182+
tensor_metadatas) {
183+
std::vector<int> output_indices = {
184+
FindIndexByMetadataTensorName(tensor_metadatas, kLocationTensorName),
185+
FindIndexByMetadataTensorName(tensor_metadatas, kCategoryTensorName),
186+
FindIndexByMetadataTensorName(tensor_metadatas, kScoreTensorName),
187+
FindIndexByMetadataTensorName(tensor_metadatas,
188+
kNumberOfDetectionsTensorName)};
189+
190+
for (int i = 0; i < 4; i++) {
191+
int output_index = output_indices[i];
192+
// If tensor name is not found, set the default output indices.
193+
if (output_index == -1) {
194+
LOG(WARNING) << absl::StrFormat(
195+
"You don't seem to be matching tensor names in metadata list. The "
196+
"tensor name \"%s\" at index %d in the model metadata doesn't match "
197+
"the available output names: [\"%s\", \"%s\", \"%s\", \"%s\"].",
198+
tensor_metadatas->Get(i)->name()->c_str(), i, kLocationTensorName,
199+
kCategoryTensorName, kScoreTensorName, kNumberOfDetectionsTensorName);
200+
output_indices = {kDefaultLocationsIndex, kDefaultClassesIndex,
201+
kDefaultScoresIndex, kDefaultNumResultsIndex};
202+
return output_indices;
203+
}
204+
}
205+
return output_indices;
206+
}
207+
169208
absl::Status SanityCheckOutputTensors(
170-
const std::vector<const TfLiteTensor*>& output_tensors) {
209+
const std::vector<const TfLiteTensor*>& output_tensors,
210+
const std::vector<int>& output_indices) {
171211
if (output_tensors.size() != 4) {
172212
return CreateStatusWithPayload(
173213
StatusCode::kInternal,
@@ -176,51 +216,53 @@ absl::Status SanityCheckOutputTensors(
176216
}
177217

178218
// Get number of results.
179-
if (output_tensors[kDefaultNumResultsIndex]->dims->data[0] != 1) {
219+
const TfLiteTensor* num_results_tensor = output_tensors[output_indices[3]];
220+
if (num_results_tensor->dims->data[0] != 1) {
180221
return CreateStatusWithPayload(
181222
StatusCode::kInternal,
182223
absl::StrFormat(
183224
"Expected tensor with dimensions [1] at index 3, found [%d]",
184-
output_tensors[kDefaultNumResultsIndex]->dims->data[0]));
225+
num_results_tensor->dims->data[0]));
185226
}
186-
int num_results = static_cast<int>(AssertAndReturnTypedTensor<float>(
187-
output_tensors[kDefaultNumResultsIndex])[0]);
227+
int num_results = static_cast<int>(
228+
AssertAndReturnTypedTensor<float>(num_results_tensor)[0]);
188229

230+
const TfLiteTensor* location_tensor = output_tensors[output_indices[0]];
189231
// Check dimensions for the other tensors are correct.
190-
if (output_tensors[kDefaultLocationsIndex]->dims->data[0] != 1 ||
191-
output_tensors[kDefaultLocationsIndex]->dims->data[1] < num_results ||
192-
output_tensors[kDefaultLocationsIndex]->dims->data[2] != 4) {
232+
if (location_tensor->dims->data[0] != 1 ||
233+
location_tensor->dims->data[1] < num_results ||
234+
location_tensor->dims->data[2] != 4) {
193235
return CreateStatusWithPayload(
194236
StatusCode::kInternal,
195237
absl::StrFormat(
196238
"Expected locations tensor with dimensions [1, num_detected_boxes, "
197-
"4] at index %d, num_detected_boxes >= %d, found [%d,%d,%d].",
198-
kDefaultLocationsIndex, num_results,
199-
output_tensors[kDefaultLocationsIndex]->dims->data[0],
200-
output_tensors[kDefaultLocationsIndex]->dims->data[1],
201-
output_tensors[kDefaultLocationsIndex]->dims->data[2]));
202-
}
203-
if (output_tensors[kDefaultClassesIndex]->dims->data[0] != 1 ||
204-
output_tensors[kDefaultClassesIndex]->dims->data[1] < num_results) {
239+
"4] at index 0, num_detected_boxes >= %d, found [%d,%d,%d].",
240+
num_results, location_tensor->dims->data[0],
241+
location_tensor->dims->data[1], location_tensor->dims->data[2]));
242+
}
243+
244+
const TfLiteTensor* class_tensor = output_tensors[output_indices[1]];
245+
if (class_tensor->dims->data[0] != 1 ||
246+
class_tensor->dims->data[1] < num_results) {
205247
return CreateStatusWithPayload(
206248
StatusCode::kInternal,
207249
absl::StrFormat(
208250
"Expected classes tensor with dimensions [1, num_detected_boxes] "
209-
"at index %d, num_detected_boxes >= %d, found [%d,%d].",
210-
kDefaultClassesIndex, num_results,
211-
output_tensors[kDefaultClassesIndex]->dims->data[0],
212-
output_tensors[kDefaultClassesIndex]->dims->data[1]));
251+
"at index 1, num_detected_boxes >= %d, found [%d,%d].",
252+
num_results, class_tensor->dims->data[0],
253+
class_tensor->dims->data[1]));
213254
}
214-
if (output_tensors[kDefaultScoresIndex]->dims->data[0] != 1 ||
215-
output_tensors[kDefaultScoresIndex]->dims->data[1] < num_results) {
255+
256+
const TfLiteTensor* scores_tensor = output_tensors[output_indices[2]];
257+
if (scores_tensor->dims->data[0] != 1 ||
258+
scores_tensor->dims->data[1] < num_results) {
216259
return CreateStatusWithPayload(
217260
StatusCode::kInternal,
218261
absl::StrFormat(
219262
"Expected scores tensor with dimensions [1, num_detected_boxes] "
220-
"at index %d, num_detected_boxes >= %d, found [%d,%d].",
221-
kDefaultScoresIndex, num_results,
222-
output_tensors[kDefaultScoresIndex]->dims->data[0],
223-
output_tensors[kDefaultScoresIndex]->dims->data[1]));
263+
"at index 2, num_detected_boxes >= %d, found [%d,%d].",
264+
num_results, scores_tensor->dims->data[0],
265+
scores_tensor->dims->data[1]));
224266
}
225267

226268
return absl::OkStatus();
@@ -409,25 +451,6 @@ absl::Status ObjectDetector::CheckAndSetOutputs() {
409451
TfLiteEngine::OutputCount(interpreter)),
410452
TfLiteSupportStatus::kInvalidNumOutputTensorsError);
411453
}
412-
// Check tensor dimensions and batch size.
413-
for (int i = 0; i < 4; ++i) {
414-
const TfLiteTensor* tensor = TfLiteEngine::GetOutput(interpreter, i);
415-
if (tensor->dims->size != kOutputTensorsExpectedDims[i]) {
416-
return CreateStatusWithPayload(
417-
StatusCode::kInvalidArgument,
418-
absl::StrFormat("Output tensor at index %d is expected to "
419-
"have %d dimensions, found %d.",
420-
i, kOutputTensorsExpectedDims[i], tensor->dims->size),
421-
TfLiteSupportStatus::kInvalidOutputTensorDimensionsError);
422-
}
423-
if (tensor->dims->data[0] != 1) {
424-
return CreateStatusWithPayload(
425-
StatusCode::kInvalidArgument,
426-
absl::StrFormat("Expected batch size of 1, found %d.",
427-
tensor->dims->data[0]),
428-
TfLiteSupportStatus::kInvalidOutputTensorDimensionsError);
429-
}
430-
}
431454

432455
// Now, perform sanity checks and extract metadata.
433456
const ModelMetadataExtractor* metadata_extractor =
@@ -455,10 +478,13 @@ absl::Status ObjectDetector::CheckAndSetOutputs() {
455478
TfLiteSupportStatus::kMetadataInconsistencyError);
456479
}
457480

481+
output_indices_ = GetOutputIndices(output_tensors_metadata);
482+
458483
// Extract mandatory BoundingBoxProperties for easier access at
459484
// post-processing time, performing sanity checks on the fly.
460485
ASSIGN_OR_RETURN(const BoundingBoxProperties* bounding_box_properties,
461-
GetBoundingBoxProperties(*output_tensors_metadata->Get(0)));
486+
GetBoundingBoxProperties(
487+
*output_tensors_metadata->Get(output_indices_[0])));
462488
if (bounding_box_properties->index() == nullptr) {
463489
bounding_box_corners_order_ = {0, 1, 2, 3};
464490
} else {
@@ -474,16 +500,39 @@ absl::Status ObjectDetector::CheckAndSetOutputs() {
474500
// Build label map (if available) from metadata.
475501
ASSIGN_OR_RETURN(
476502
label_map_,
477-
GetLabelMapIfAny(*metadata_extractor, *output_tensors_metadata->Get(1),
503+
GetLabelMapIfAny(*metadata_extractor,
504+
*output_tensors_metadata->Get(output_indices_[1]),
478505
options_->display_names_locale()));
479506

480507
// Set score threshold.
481508
if (options_->has_score_threshold()) {
482509
score_threshold_ = options_->score_threshold();
483510
} else {
484-
ASSIGN_OR_RETURN(score_threshold_,
485-
GetScoreThreshold(*metadata_extractor,
486-
*output_tensors_metadata->Get(2)));
511+
ASSIGN_OR_RETURN(
512+
score_threshold_,
513+
GetScoreThreshold(*metadata_extractor,
514+
*output_tensors_metadata->Get(output_indices_[2])));
515+
}
516+
517+
// Check tensor dimensions and batch size.
518+
for (int i = 0; i < 4; ++i) {
519+
std::size_t j = output_indices_[i];
520+
const TfLiteTensor* tensor = TfLiteEngine::GetOutput(interpreter, j);
521+
if (tensor->dims->size != kOutputTensorsExpectedDims[i]) {
522+
return CreateStatusWithPayload(
523+
StatusCode::kInvalidArgument,
524+
absl::StrFormat("Output tensor at index %d is expected to "
525+
"have %d dimensions, found %d.",
526+
j, kOutputTensorsExpectedDims[i], tensor->dims->size),
527+
TfLiteSupportStatus::kInvalidOutputTensorDimensionsError);
528+
}
529+
if (tensor->dims->data[0] != 1) {
530+
return CreateStatusWithPayload(
531+
StatusCode::kInvalidArgument,
532+
absl::StrFormat("Expected batch size of 1, found %d.",
533+
tensor->dims->data[0]),
534+
TfLiteSupportStatus::kInvalidOutputTensorDimensionsError);
535+
}
487536
}
488537

489538
return absl::OkStatus();
@@ -551,11 +600,11 @@ StatusOr<DetectionResult> ObjectDetector::Postprocess(
551600
// Most of the checks here should never happen, as outputs have been validated
552601
// at construction time. Checking nonetheless and returning internal errors if
553602
// something bad happens.
554-
RETURN_IF_ERROR(SanityCheckOutputTensors(output_tensors));
603+
RETURN_IF_ERROR(SanityCheckOutputTensors(output_tensors, output_indices_));
555604

556605
// Get number of available results.
557-
const int num_results = static_cast<int>(AssertAndReturnTypedTensor<float>(
558-
output_tensors[kDefaultNumResultsIndex])[0]);
606+
const int num_results = static_cast<int>(
607+
AssertAndReturnTypedTensor<float>(output_tensors[output_indices_[3]])[0]);
559608
// Compute number of max results to return.
560609
const int max_results = options_->max_results() > 0
561610
? std::min(options_->max_results(), num_results)
@@ -569,11 +618,11 @@ StatusOr<DetectionResult> ObjectDetector::Postprocess(
569618
upright_input_frame_dimensions.Swap();
570619
}
571620
const float* locations =
572-
AssertAndReturnTypedTensor<float>(output_tensors[kDefaultLocationsIndex]);
621+
AssertAndReturnTypedTensor<float>(output_tensors[output_indices_[0]]);
573622
const float* classes =
574-
AssertAndReturnTypedTensor<float>(output_tensors[kDefaultClassesIndex]);
623+
AssertAndReturnTypedTensor<float>(output_tensors[output_indices_[1]]);
575624
const float* scores =
576-
AssertAndReturnTypedTensor<float>(output_tensors[kDefaultScoresIndex]);
625+
AssertAndReturnTypedTensor<float>(output_tensors[output_indices_[2]]);
577626
DetectionResult results;
578627
for (int i = 0; i < num_results; ++i) {
579628
const int class_index = static_cast<int>(classes[i]);

tensorflow_lite_support/cc/task/vision/object_detector.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,11 @@ class ObjectDetector : public BaseVisionTaskApi<DetectionResult> {
190190
// List of score calibration parameters, if any. Built from TFLite Model
191191
// Metadata.
192192
std::unique_ptr<ScoreCalibration> score_calibration_;
193+
194+
// Maps each expected output, in the order [location, categories, scores,
195+
// num_detections], to the index of the corresponding tensor among the
196+
// model's output tensors.
197+
std::vector<int> output_indices_;
193198
};
194199

195200
} // namespace vision

tensorflow_lite_support/cc/test/task/vision/object_detector_test.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,9 @@ constexpr char kExpectResults[] =
9595
)pb";
9696
constexpr char kMobileSsdWithMetadataDummyScoreCalibration[] =
9797
"coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_score_calibration.tflite";
98+
// The model has different output tensor order.
99+
constexpr char kEfficientDetWithMetadata[] =
100+
"coco_efficientdet_lite0_v1_1.0_quant_2021_09_06.tflite";
98101

99102
StatusOr<ImageData> LoadImage(std::string image_name) {
100103
return DecodeImageFromFile(JoinPath("./" /*test src dir*/,

0 commit comments

Comments
 (0)