Skip to content

Commit e4ccd75

Browse files
lu-wang-g authored and tflite-support-robot committed
Refactor Preprocessor/Postprocessor
Major change: 1. Created a new base class, Processor, to let Preprocessor/Postprocessor share common logic of creation, sanity check, metadata extracting etc. 2. Created a new factory method to build a subclass of Processor. Simplified Processor creation by encapsulating the constructor and SanityCheck together. 3. Updated method names to follow the tflite_support convention: Tensor() -> GetTensor(), Metadata() -> GetTensorMetadata(). PiperOrigin-RevId: 408735018
1 parent 7988c2b commit e4ccd75

File tree

8 files changed

+179
-165
lines changed

8 files changed

+179
-165
lines changed

tensorflow_lite_support/cc/task/processor/audio_preprocessor.cc

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -74,11 +74,10 @@ tflite::support::StatusOr<const AudioProperties*> GetAudioPropertiesSafe(
7474
tflite::support::StatusOr<std::unique_ptr<AudioPreprocessor>>
7575
AudioPreprocessor::Create(tflite::task::core::TfLiteEngine* engine,
7676
const std::initializer_list<int> input_indices) {
77-
RETURN_IF_ERROR(Preprocessor::SanityCheck(/* num_expected_tensors = */ 1,
78-
engine, input_indices,
79-
/* requires_metadata = */ true));
80-
auto processor =
81-
::absl::WrapUnique(new AudioPreprocessor(engine, input_indices));
77+
ASSIGN_OR_RETURN(auto processor,
78+
Processor::Create<AudioPreprocessor>(
79+
/* num_expected_tensors = */ 1, engine, input_indices));
80+
8281
RETURN_IF_ERROR(processor->Init());
8382
return processor;
8483
}
@@ -90,8 +89,9 @@ absl::Status AudioPreprocessor::Init() {
9089
}
9190

9291
absl::Status AudioPreprocessor::SetAudioFormatFromMetadata() {
93-
ASSIGN_OR_RETURN(const AudioProperties* props,
94-
GetAudioPropertiesSafe(Metadata(), input_indices_.at(0)));
92+
ASSIGN_OR_RETURN(
93+
const AudioProperties* props,
94+
GetAudioPropertiesSafe(GetTensorMetadata(), tensor_indices_.at(0)));
9595
audio_format_.channels = props->channels();
9696
audio_format_.sample_rate = props->sample_rate();
9797
if (audio_format_.channels <= 0 || audio_format_.sample_rate <= 0) {
@@ -106,16 +106,16 @@ absl::Status AudioPreprocessor::SetAudioFormatFromMetadata() {
106106

107107
absl::Status AudioPreprocessor::CheckAndSetInputs() {
108108
input_buffer_size_ = 1;
109-
for (int i = 0; i < Tensor()->dims->size; i++) {
110-
if (Tensor()->dims->data[i] < 1) {
109+
for (int i = 0; i < GetTensor()->dims->size; i++) {
110+
if (GetTensor()->dims->data[i] < 1) {
111111
return CreateStatusWithPayload(
112112
absl::StatusCode::kInvalidArgument,
113113
absl::StrFormat("Invalid size: %d for input tensor dimension: %d.",
114-
Tensor()->dims->data[i], i),
114+
GetTensor()->dims->data[i], i),
115115
tflite::support::TfLiteSupportStatus::
116116
kInvalidInputTensorDimensionsError);
117117
}
118-
input_buffer_size_ *= Tensor()->dims->data[i];
118+
input_buffer_size_ *= GetTensor()->dims->data[i];
119119
}
120120
// Check if the input buffer size is divisible by the required audio channels.
121121
// This needs to be done after loading metadata and input.
@@ -158,7 +158,7 @@ absl::Status AudioPreprocessor::Preprocess(
158158
tflite::support::TfLiteSupportStatus::kInvalidArgumentError);
159159
}
160160
return tflite::task::core::PopulateTensor(audio_buffer.GetFloatBuffer(),
161-
input_buffer_size_, Tensor());
161+
input_buffer_size_, GetTensor());
162162
}
163163

164164
} // namespace processor

tensorflow_lite_support/cc/task/processor/classification_postprocessor.cc

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,9 @@ tflite::support::StatusOr<std::unique_ptr<ClassificationPostprocessor>>
4444
ClassificationPostprocessor::Create(
4545
core::TfLiteEngine* engine, const std::initializer_list<int> output_indices,
4646
std::unique_ptr<ClassificationOptions> options) {
47-
RETURN_IF_ERROR(Postprocessor::SanityCheck(/* num_expected_tensors = */ 1,
48-
engine, output_indices));
49-
50-
auto processor =
51-
absl::WrapUnique(new ClassificationPostprocessor(engine, output_indices));
47+
ASSIGN_OR_RETURN(auto processor,
48+
Processor::Create<ClassificationPostprocessor>(
49+
/* num_expected_tensors = */ 1, engine, output_indices));
5250

5351
RETURN_IF_ERROR(processor->Init(std::move(options)));
5452
return processor;
@@ -72,13 +70,13 @@ absl::Status ClassificationPostprocessor::Init(
7270
TfLiteSupportStatus::kInvalidArgumentError);
7371
}
7472

75-
ASSIGN_OR_RETURN(
76-
classification_head_,
77-
BuildClassificationHead(*engine_->metadata_extractor(), *Metadata(),
78-
options->display_names_locale()));
73+
ASSIGN_OR_RETURN(classification_head_,
74+
BuildClassificationHead(*engine_->metadata_extractor(),
75+
*GetTensorMetadata(),
76+
options->display_names_locale()));
7977

8078
// Sanity check output tensors
81-
const TfLiteTensor* output_tensor = Tensor();
79+
const TfLiteTensor* output_tensor = GetTensor();
8280
const int num_dimensions = output_tensor->dims->size;
8381
if (num_dimensions == 4) {
8482
if (output_tensor->dims->data[1] != 1 ||
@@ -87,7 +85,7 @@ absl::Status ClassificationPostprocessor::Init(
8785
StatusCode::kInvalidArgument,
8886
absl::StrFormat("Unexpected WxH sizes for output index %d: got "
8987
"%dx%d, expected 1x1.",
90-
output_indices_.at(0), output_tensor->dims->data[2],
88+
tensor_indices_.at(0), output_tensor->dims->data[2],
9189
output_tensor->dims->data[1]),
9290
TfLiteSupportStatus::kInvalidOutputTensorDimensionsError);
9391
}
@@ -98,15 +96,15 @@ absl::Status ClassificationPostprocessor::Init(
9896
"Unexpected number of dimensions for output index %d: got %dD, "
9997
"expected either 2D (BxN with B=1) or 4D (BxHxWxN with B=1, W=1, "
10098
"H=1).",
101-
output_indices_.at(0), num_dimensions),
99+
tensor_indices_.at(0), num_dimensions),
102100
TfLiteSupportStatus::kInvalidOutputTensorDimensionsError);
103101
}
104102
if (output_tensor->dims->data[0] != 1) {
105103
return CreateStatusWithPayload(
106104
StatusCode::kInvalidArgument,
107105
absl::StrFormat("The output array is expected to have a batch size "
108106
"of 1. Got %d for output index %d.",
109-
output_tensor->dims->data[0], output_indices_.at(0)),
107+
output_tensor->dims->data[0], tensor_indices_.at(0)),
110108
TfLiteSupportStatus::kInvalidOutputTensorDimensionsError);
111109
}
112110
int num_classes = output_tensor->dims->data[num_dimensions - 1];
@@ -126,7 +124,7 @@ absl::Status ClassificationPostprocessor::Init(
126124
absl::StrFormat("Got %d class(es) for output index %d, expected %d "
127125
"according to the label map.",
128126
output_tensor->dims->data[num_dimensions - 1],
129-
output_indices_.at(0), num_label_map_items),
127+
tensor_indices_.at(0), num_label_map_items),
130128
TfLiteSupportStatus::kMetadataInconsistencyError);
131129
}
132130
if (output_tensor->type != kTfLiteUInt8 &&
@@ -156,7 +154,7 @@ absl::Status ClassificationPostprocessor::Init(
156154
if (head_class_names.empty()) {
157155
std::string name = classification_head_.name;
158156
if (name.empty()) {
159-
name = absl::StrFormat("#%d", output_indices_.at(0));
157+
name = absl::StrFormat("#%d", tensor_indices_.at(0));
160158
}
161159
return CreateStatusWithPayload(
162160
StatusCode::kInvalidArgument,

tensorflow_lite_support/cc/task/processor/classification_postprocessor.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,12 +100,12 @@ class ClassificationPostprocessor : public Postprocessor {
100100
template <typename T>
101101
absl::Status ClassificationPostprocessor::Postprocess(T* classifications) {
102102
const auto& head = classification_head_;
103-
classifications->set_head_index(output_indices_.at(0));
103+
classifications->set_head_index(tensor_indices_.at(0));
104104

105105
std::vector<std::pair<int, float>> score_pairs;
106106
score_pairs.reserve(head.label_map_items.size());
107107

108-
const TfLiteTensor* output_tensor = Tensor();
108+
const TfLiteTensor* output_tensor = GetTensor();
109109
if (output_tensor->type == kTfLiteUInt8) {
110110
ASSIGN_OR_RETURN(const uint8* output_data,
111111
core::AssertAndReturnTypedTensor<uint8>(output_tensor));

tensorflow_lite_support/cc/task/processor/embedding_postprocessor.cc

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,10 @@ tflite::support::StatusOr<std::unique_ptr<EmbeddingPostprocessor>>
2424
EmbeddingPostprocessor::Create(core::TfLiteEngine* engine,
2525
const std::initializer_list<int> output_indices,
2626
std::unique_ptr<EmbeddingOptions> options) {
27-
RETURN_IF_ERROR(Postprocessor::SanityCheck(/* num_expected_tensors = */ 1,
28-
engine, output_indices,
29-
/* requires_metadata = */ false));
30-
31-
auto processor =
32-
absl::WrapUnique(new EmbeddingPostprocessor(engine, output_indices));
27+
ASSIGN_OR_RETURN(auto processor,
28+
Processor::Create<EmbeddingPostprocessor>(
29+
/* num_expected_tensors = */ 1, engine, output_indices,
30+
/* requires_metadata = */ false));
3331

3432
RETURN_IF_ERROR(processor->Init(std::move(options)));
3533
return processor;
@@ -39,8 +37,8 @@ absl::Status EmbeddingPostprocessor::Init(
3937
std::unique_ptr<EmbeddingOptions> options) {
4038
options_ = std::move(options);
4139

42-
int output_index = output_indices_.at(0);
43-
auto* output_tensor = Tensor();
40+
int output_index = tensor_indices_.at(0);
41+
auto* output_tensor = GetTensor();
4442
int num_dimensions = output_tensor->dims->size;
4543
if (num_dimensions == 4) {
4644
if (output_tensor->dims->data[1] != 1 ||

tensorflow_lite_support/cc/task/processor/embedding_postprocessor.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -78,15 +78,15 @@ class EmbeddingPostprocessor : public Postprocessor {
7878

7979
template <typename T>
8080
absl::Status EmbeddingPostprocessor::Postprocess(T* embedding) {
81-
embedding->set_output_index(output_indices_.at(0));
81+
embedding->set_output_index(tensor_indices_.at(0));
8282
auto* feature_vector = embedding->mutable_feature_vector();
83-
if (Tensor()->type == kTfLiteUInt8) {
83+
if (GetTensor()->type == kTfLiteUInt8) {
8484
const uint8* output_data =
8585
engine_->interpreter()->typed_output_tensor<uint8>(
86-
output_indices_.at(0));
86+
tensor_indices_.at(0));
8787
// Get the zero_point and scale parameters from the tensor metadata.
8888
const int output_tensor_index =
89-
engine_->interpreter()->outputs()[output_indices_.at(0)];
89+
engine_->interpreter()->outputs()[tensor_indices_.at(0)];
9090
const TfLiteTensor* output_tensor =
9191
engine_->interpreter()->tensor(output_tensor_index);
9292
for (int j = 0; j < embedding_dimension_; ++j) {
@@ -98,7 +98,7 @@ absl::Status EmbeddingPostprocessor::Postprocess(T* embedding) {
9898
// Float
9999
const float* output_data =
100100
engine_->interpreter()->typed_output_tensor<float>(
101-
output_indices_.at(0));
101+
tensor_indices_.at(0));
102102
for (int j = 0; j < embedding_dimension_; ++j) {
103103
feature_vector->add_value_float(output_data[j]);
104104
}

tensorflow_lite_support/cc/task/processor/image_preprocessor.cc

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,10 @@ tflite::support::StatusOr<std::unique_ptr<ImagePreprocessor>>
3838
ImagePreprocessor::Create(
3939
core::TfLiteEngine* engine, const std::initializer_list<int> input_indices,
4040
const vision::FrameBufferUtils::ProcessEngine& process_engine) {
41-
RETURN_IF_ERROR(Preprocessor::SanityCheck(/* num_expected_tensors = */ 1,
42-
engine, input_indices,
43-
/* requires_metadata = */ false));
44-
45-
auto processor =
46-
absl::WrapUnique(new ImagePreprocessor(engine, input_indices));
41+
ASSIGN_OR_RETURN(auto processor,
42+
Processor::Create<ImagePreprocessor>(
43+
/* num_expected_tensors = */ 1, engine, input_indices,
44+
/* requires_metadata = */ false));
4745

4846
RETURN_IF_ERROR(processor->Init(process_engine));
4947
return processor;
@@ -85,7 +83,7 @@ absl::Status ImagePreprocessor::Init(
8583
}
8684

8785
// Determine if the input shape is resizable.
88-
const TfLiteIntArray* dims_signature = Tensor()->dims_signature;
86+
const TfLiteIntArray* dims_signature = GetTensor()->dims_signature;
8987

9088
// Some fixed-shape models do not have dims_signature.
9189
if (dims_signature != nullptr && dims_signature->size > 2) {
@@ -152,26 +150,26 @@ absl::Status ImagePreprocessor::Preprocess(const FrameBuffer& frame_buffer,
152150
// If dynamic, it will re-dim the entire graph as per the input.
153151
if (is_height_mutable_ || is_width_mutable_) {
154152
engine_->interpreter()->ResizeInputTensorStrict(
155-
0, {Tensor()->dims->data[0], input_specs_.image_height,
156-
input_specs_.image_width, Tensor()->dims->data[3]});
153+
0, {GetTensor()->dims->data[0], input_specs_.image_height,
154+
input_specs_.image_width, GetTensor()->dims->data[3]});
157155

158156
engine_->interpreter()->AllocateTensors();
159157
}
160158
// Then normalize pixel data (if needed) and populate the input tensor.
161159
switch (input_specs_.tensor_type) {
162160
case kTfLiteUInt8:
163-
if (Tensor()->bytes != input_data_byte_size) {
161+
if (GetTensor()->bytes != input_data_byte_size) {
164162
return tflite::support::CreateStatusWithPayload(
165163
absl::StatusCode::kInternal,
166164
"Size mismatch or unsupported padding bytes between pixel data "
167165
"and input tensor.");
168166
}
169167
// No normalization required: directly populate data.
170168
RETURN_IF_ERROR(tflite::task::core::PopulateTensor(
171-
input_data, input_data_byte_size / sizeof(uint8), Tensor()));
169+
input_data, input_data_byte_size / sizeof(uint8), GetTensor()));
172170
break;
173171
case kTfLiteFloat32: {
174-
if (Tensor()->bytes / sizeof(float) !=
172+
if (GetTensor()->bytes / sizeof(float) !=
175173
input_data_byte_size / sizeof(uint8)) {
176174
return tflite::support::CreateStatusWithPayload(
177175
absl::StatusCode::kInternal,
@@ -181,7 +179,7 @@ absl::Status ImagePreprocessor::Preprocess(const FrameBuffer& frame_buffer,
181179
// Normalize and populate.
182180
ASSIGN_OR_RETURN(
183181
float* normalized_input_data,
184-
tflite::task::core::AssertAndReturnTypedTensor<float>(Tensor()));
182+
tflite::task::core::AssertAndReturnTypedTensor<float>(GetTensor()));
185183
const tflite::task::vision::NormalizationOptions& normalization_options =
186184
input_specs_.normalization_options.value();
187185
for (int i = 0; i < normalization_options.num_values; i++) {

tensorflow_lite_support/cc/task/processor/processor.cc

Lines changed: 18 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -18,72 +18,35 @@ namespace tflite {
1818
namespace task {
1919
namespace processor {
2020

21-
/* static */
22-
absl::Status Preprocessor::SanityCheck(
23-
int num_expected_tensors, core::TfLiteEngine* engine,
24-
const std::initializer_list<int> input_indices, bool requires_metadata) {
25-
if (input_indices.size() != num_expected_tensors) {
26-
return support::CreateStatusWithPayload(
27-
absl::StatusCode::kInvalidArgument,
28-
absl::StrFormat("Preprocessor can handle %d tensors, "
29-
"got: %d tensors.",
30-
num_expected_tensors, input_indices.size()));
31-
}
32-
for (auto* p = input_indices.begin(); p < input_indices.end(); p++) {
33-
int input_index = *p;
34-
if (input_index < 0 ||
35-
input_index >= engine->InputCount(engine->interpreter())) {
36-
return support::CreateStatusWithPayload(
37-
absl::StatusCode::kInvalidArgument,
38-
absl::StrFormat(
39-
"Invalid input_index: %d. Model has %d input tensors.",
40-
input_index, engine->InputCount(engine->interpreter())));
41-
}
42-
if (requires_metadata) {
43-
auto* metadata =
44-
engine->metadata_extractor()->GetInputTensorMetadata(input_index);
45-
if (metadata == nullptr) {
46-
return CreateStatusWithPayload(
47-
absl::StatusCode::kInvalidArgument,
48-
absl::StrFormat("Input tensor %d is missing TensorMetadata.",
49-
input_index),
50-
support::TfLiteSupportStatus::kMetadataNotFoundError);
51-
}
52-
}
53-
}
21+
constexpr char Preprocessor::kInputTypeName[];
22+
constexpr char Postprocessor::kOutputTypeName[];
5423

55-
return absl::OkStatus();
56-
}
57-
58-
/* static */
59-
absl::Status Postprocessor::SanityCheck(
60-
int num_expected_tensors, core::TfLiteEngine* engine,
61-
const std::initializer_list<int> output_indices, bool requires_metadata) {
62-
if (output_indices.size() != num_expected_tensors) {
24+
absl::Status Processor::SanityCheck(int num_expected_tensors,
25+
bool requires_metadata) {
26+
const char* tensor_type_name = GetTensorTypeName();
27+
if (tensor_indices_.size() != num_expected_tensors) {
6328
return support::CreateStatusWithPayload(
6429
absl::StatusCode::kInvalidArgument,
65-
absl::StrFormat("Postprocessor can handle %d tensors, "
30+
absl::StrFormat("Processor can handle %d tensors, "
6631
"got: %d tensors.",
67-
num_expected_tensors, output_indices.size()));
32+
num_expected_tensors, tensor_indices_.size()));
6833
}
69-
for (auto* p = output_indices.begin(); p < output_indices.end(); p++) {
70-
int output_index = *p;
71-
if (output_index < 0 ||
72-
output_index >= engine->OutputCount(engine->interpreter())) {
34+
35+
int tensor_count = GetModelTensorCount();
36+
for (int i = 0; i < tensor_indices_.size(); i++) {
37+
int index = tensor_indices_.at(i);
38+
if (index < 0 || index >= tensor_count) {
7339
return support::CreateStatusWithPayload(
7440
absl::StatusCode::kInvalidArgument,
75-
absl::StrFormat(
76-
"Invalid output_index: %d. Model has %d output tensors.",
77-
output_index, engine->OutputCount(engine->interpreter())));
41+
absl::StrFormat("Invalid tensor_index: %d. Model has %d %s tensors.",
42+
index, tensor_count, tensor_type_name));
7843
}
7944
if (requires_metadata) {
80-
auto* metadata =
81-
engine->metadata_extractor()->GetOutputTensorMetadata(output_index);
82-
if (metadata == nullptr) {
45+
if (GetTensorMetadata(i) == nullptr) {
8346
return CreateStatusWithPayload(
8447
absl::StatusCode::kInvalidArgument,
85-
absl::StrFormat("Output tensor %d is missing TensorMetadata.",
86-
output_index),
48+
absl::StrFormat("%s tensor %d is missing TensorMetadata.",
49+
tensor_type_name, index),
8750
support::TfLiteSupportStatus::kMetadataNotFoundError);
8851
}
8952
}

0 commit comments

Comments
 (0)