Skip to content

Commit 84ca38b

Browse files
committed
Return type of CreateImageClassifierCppOptionsFromCOptions changed to StatusOr<ImageClassifierOptionsCpp>
1 parent dc35aae commit 84ca38b

File tree

1 file changed

+27
-25
lines changed

1 file changed

+27
-25
lines changed

tensorflow_lite_support/c/task/vision/image_classifier.cc

Lines changed: 27 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -33,22 +33,28 @@ using ImageClassifierCpp = ::tflite::task::vision::ImageClassifier;
3333
using ImageClassifierOptionsCpp =
3434
::tflite::task::vision::ImageClassifierOptions;
3535
using FrameBufferCpp = ::tflite::task::vision::FrameBuffer;
36+
using ::tflite::support::TfLiteSupportStatus;
3637

37-
std::unique_ptr<ImageClassifierOptionsCpp>
38-
CreateImageClassifierCppOptionsFromCOptions(
38+
StatusOr<ImageClassifierOptionsCpp> CreateImageClassifierCppOptionsFromCOptions(
3939
const TfLiteImageClassifierOptions* c_options) {
40-
std::unique_ptr<ImageClassifierOptionsCpp> cpp_options(
41-
new ImageClassifierOptionsCpp);
40+
if (c_options == nullptr) {
41+
return CreateStatusWithPayload(
42+
absl::StatusCode::kInvalidArgument,
43+
absl::StrFormat("Expected non null options."),
44+
TfLiteSupportStatus::kInvalidArgumentError);
45+
}
46+
47+
ImageClassifierOptionsCpp cpp_options = {};
4248

4349
// More file sources can be added in else ifs
4450
if (c_options->base_options.model_file.file_path)
45-
cpp_options->mutable_base_options()->mutable_model_file()->set_file_name(
51+
cpp_options.mutable_base_options()->mutable_model_file()->set_file_name(
4652
c_options->base_options.model_file.file_path);
4753

4854
// c_options->base_options.compute_settings.num_threads is expected to be
4955
// set to value > 0 or -1. Otherwise invoking
5056
// ImageClassifierCpp::CreateFromOptions() results in a not ok status.
51-
cpp_options->mutable_base_options()
57+
cpp_options.mutable_base_options()
5258
->mutable_compute_settings()
5359
->mutable_tflite_settings()
5460
->mutable_cpu_settings()
@@ -57,27 +63,27 @@ CreateImageClassifierCppOptionsFromCOptions(
5763

5864
for (int i = 0; i < c_options->classification_options.label_denylist.length;
5965
i++)
60-
cpp_options->add_class_name_blacklist(
66+
cpp_options.add_class_name_blacklist(
6167
c_options->classification_options.label_denylist.list[i]);
6268

6369
for (int i = 0; i < c_options->classification_options.label_allowlist.length;
6470
i++)
65-
cpp_options->add_class_name_whitelist(
71+
cpp_options.add_class_name_whitelist(
6672
c_options->classification_options.label_allowlist.list[i]);
6773

6874
// Check needed since setting a nullptr for this field results in a segfault
6975
// on invocation of ImageClassifierCpp::CreateFromOptions().
7076
if (c_options->classification_options.display_names_local) {
71-
cpp_options->set_display_names_locale(
77+
cpp_options.set_display_names_locale(
7278
c_options->classification_options.display_names_local);
7379
}
7480

7581
// c_options->classification_options.max_results is expected to be set to -1
7682
// or any value > 0. Otherwise invoking
7783
// ImageClassifierCpp::CreateFromOptions() results in a not ok status.
78-
cpp_options->set_max_results(c_options->classification_options.max_results);
84+
cpp_options.set_max_results(c_options->classification_options.max_results);
7985

80-
cpp_options->set_score_threshold(
86+
cpp_options.set_score_threshold(
8187
c_options->classification_options.score_threshold);
8288

8389
return cpp_options;
@@ -103,22 +109,17 @@ TfLiteImageClassifierOptions TfLiteImageClassifierOptionsCreate() {
103109

104110
TfLiteImageClassifier* TfLiteImageClassifierFromOptions(
105111
const TfLiteImageClassifierOptions* options, TfLiteSupportError** error) {
106-
if (options == nullptr) {
107-
tflite::support::CreateTfLiteSupportError(
108-
kInvalidArgumentError, "Expected non null options.", error);
109-
return nullptr;
110-
}
111-
112-
std::unique_ptr<ImageClassifierOptionsCpp> cpp_options =
112+
StatusOr<ImageClassifierOptionsCpp> cpp_option_status =
113113
CreateImageClassifierCppOptionsFromCOptions(options);
114114

115-
if (cpp_options == nullptr) {
116-
tflite::support::CreateTfLiteSupportError(
117-
kError, "Some error occured.", error);
115+
if (!cpp_option_status.ok()) {
116+
::tflite::support::CreateTfLiteSupportErrorWithStatus(
117+
cpp_option_status.status(), error);
118118
return nullptr;
119119
}
120120

121-
auto classifier_status = ImageClassifierCpp::CreateFromOptions(*cpp_options);
121+
StatusOr<std::unique_ptr<ImageClassifierCpp>> classifier_status =
122+
ImageClassifierCpp::CreateFromOptions(cpp_option_status.value());
122123

123124
if (classifier_status.ok()) {
124125
return new TfLiteImageClassifier{.impl =
@@ -194,7 +195,7 @@ TfLiteClassificationResult* TfLiteImageClassifierClassifyWithRoi(
194195
tflite::support::CreateTfLiteSupportErrorWithStatus(
195196
cpp_frame_buffer_status.status(), error);
196197
return nullptr;
197-
}
198+
}
198199

199200
BoundingBoxCpp cc_roi;
200201
if (roi == nullptr) {
@@ -206,7 +207,7 @@ TfLiteClassificationResult* TfLiteImageClassifierClassifyWithRoi(
206207
cc_roi.set_width(roi->width);
207208
cc_roi.set_height(roi->height);
208209
}
209-
210+
210211
// fnc_sample(cpp_frame_buffer_status);
211212
StatusOr<ClassificationResultCpp> cpp_classification_result_status =
212213
classifier->impl->Classify(*std::move(cpp_frame_buffer_status.value()),
@@ -218,7 +219,8 @@ TfLiteClassificationResult* TfLiteImageClassifierClassifyWithRoi(
218219
return nullptr;
219220
}
220221

221-
return GetClassificationResultCStruct(cpp_classification_result_status.value());
222+
return GetClassificationResultCStruct(
223+
cpp_classification_result_status.value());
222224
}
223225

224226
TfLiteClassificationResult* TfLiteImageClassifierClassify(

0 commit comments

Comments
 (0)