Skip to content

Commit 0d52f32

Browse files
Merge pull request #696 from khanhlvg:image-classifier-c-api
PiperOrigin-RevId: 398347514
2 parents 504c8f4 + 0ad9554 commit 0d52f32

File tree

4 files changed

+94
-80
lines changed

4 files changed

+94
-80
lines changed

tensorflow_lite_support/c/task/processor/classification_options.h

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -34,17 +34,16 @@ typedef struct TfLiteStringArrayOption {
3434
} TfLiteStringArrayOption;
3535

3636
// Holds settings for any single classification task.
37-
// TODO(prianka): change `white/blacklist` to `allow/denylist`
3837
typedef struct TfLiteClassificationOptions {
39-
// Optional blacklist of class names. If non NULL, classifications whose
40-
// class name is in this set will be filtered out. Duplicate or unknown
41-
// class names are ignored. Mutually exclusive with class_name_whitelist.
42-
TfLiteStringArrayOption class_name_blacklist;
43-
44-
// Optional whitelist of class names. If non-empty, classifications whose
45-
// class name is not in this set will be filtered out. Duplicate or unknown
46-
// class names are ignored. Mutually exclusive with class_name_blacklist.
47-
TfLiteStringArrayOption class_name_whitelist;
38+
// Optional denylist of class labels. If non NULL, classifications whose
39+
// class label is in this set will be filtered out. Duplicate or unknown
40+
// class labels are ignored. Mutually exclusive with label_allowlist.
41+
TfLiteStringArrayOption label_denylist;
42+
43+
// Optional allowlist of class labels. If non-empty, classifications whose
44+
// class label is not in this set will be filtered out. Duplicate or unknown
45+
// class labels are ignored. Mutually exclusive with label_denylist.
46+
TfLiteStringArrayOption label_allowlist;
4847

4948
const char* display_names_local;
5049
int max_results;

tensorflow_lite_support/c/task/vision/image_classifier.cc

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,12 @@ struct TfLiteImageClassifier {
4242
std::unique_ptr<ImageClassifierCpp> impl;
4343
};
4444

45+
TfLiteImageClassifierOptions TfLiteImageClassifierOptionsCreate() {
46+
return {.classification_options = {.max_results = -1},
47+
.base_options = {
48+
.compute_settings = {.cpu_settings = {.num_threads = -1}}}};
49+
}
50+
4551
std::unique_ptr<ImageClassifierOptionsCpp>
4652
CreateImageClassifierCppOptionsFromCOptions(
4753
const TfLiteImageClassifierOptions* c_options) {
@@ -55,28 +61,25 @@ CreateImageClassifierCppOptionsFromCOptions(
5561
else
5662
return nullptr;
5763

58-
// Without this check, in zero initialized TfLiteImageClassifierOptions (must
59-
// be done to prevent undefined behaviour)
60-
// c_options->base_options.compute_settings.num_threads should be explicitly
64+
// c_options->base_options.compute_settings.num_threads is expected to be
6165
// set to value > 0 or -1. Otherwise invoking
6266
// ImageClassifierCpp::CreateFromOptions() results in a not ok status.
63-
if (c_options->base_options.compute_settings.cpu_settings.num_threads > 0)
64-
cpp_options->mutable_base_options()
65-
->mutable_compute_settings()
66-
->mutable_tflite_settings()
67-
->mutable_cpu_settings()
68-
->set_num_threads(
69-
c_options->base_options.compute_settings.cpu_settings.num_threads);
70-
71-
for (int i = 0;
72-
i < c_options->classification_options.class_name_blacklist.length; i++)
67+
cpp_options->mutable_base_options()
68+
->mutable_compute_settings()
69+
->mutable_tflite_settings()
70+
->mutable_cpu_settings()
71+
->set_num_threads(
72+
c_options->base_options.compute_settings.cpu_settings.num_threads);
73+
74+
for (int i = 0; i < c_options->classification_options.label_denylist.length;
75+
i++)
7376
cpp_options->add_class_name_blacklist(
74-
c_options->classification_options.class_name_blacklist.list[i]);
77+
c_options->classification_options.label_denylist.list[i]);
7578

76-
for (int i = 0;
77-
i < c_options->classification_options.class_name_whitelist.length; i++)
79+
for (int i = 0; i < c_options->classification_options.label_allowlist.length;
80+
i++)
7881
cpp_options->add_class_name_whitelist(
79-
c_options->classification_options.class_name_whitelist.list[i]);
82+
c_options->classification_options.label_allowlist.list[i]);
8083

8184
// Check needed since setting a nullptr for this field results in a segfault
8285
// on invocation of ImageClassifierCpp::CreateFromOptions().
@@ -85,13 +88,10 @@ CreateImageClassifierCppOptionsFromCOptions(
8588
c_options->classification_options.display_names_local);
8689
}
8790

88-
// Without this check, in zero initialized TfLiteImageClassifierOptions (must
89-
// be done to prevent undefined behaviour)
90-
// c_options->classification_options.max_results should be explicitly set to
91-
// value > 0. Otherwise invocation of ImageClassifierCpp::CreateFromOptions()
92-
// will return a not ok status.
93-
if (c_options->classification_options.max_results > 0)
94-
cpp_options->set_max_results(c_options->classification_options.max_results);
91+
// c_options->classification_options.max_results is expected to be set to -1
92+
// or any value > 0. Otherwise invoking
93+
// ImageClassifierCpp::CreateFromOptions() results in a not ok status.
94+
cpp_options->set_max_results(c_options->classification_options.max_results);
9595

9696
cpp_options->set_score_threshold(
9797
c_options->classification_options.score_threshold);

tensorflow_lite_support/c/task/vision/image_classifier.h

Lines changed: 35 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,19 @@ limitations under the License.
3535
/// Usage with Model File Path:
3636
/// <pre><code>
3737
/// // Create the model
38-
/// Zero initialize options to avoid undefined behaviour due to garbage values
39-
/// for members
40-
/// TfLiteImageClassifierOptions options = {0};
38+
/// Using the options initialized with default values returned by
39+
/// TfLiteImageClassifierOptionsCreate() makes sure that there will be no
40+
/// undefined behaviour due to garbage values in unitialized members.
41+
/// TfLiteImageClassifierOptions options = TfLiteImageClassifierOptionsCreate();
42+
///
43+
/// Set the model file path in options
4144
/// options.base_options.model_file.file_path = "/path/to/model.tflite";
42-
/// TfLiteImageClassifier* image_classifier =
45+
///
46+
/// If need be, set values for any options to customize behaviour.
47+
/// options.base_options.compute_settings.cpu_settings.num_threads = 3
48+
///
49+
/// Create TfLiteImageClassifier using the options.
50+
/// TfLiteImageClassifier* image_classifier =
4351
/// TfLiteImageClassifierFromOptions(&options);
4452
///
4553
/// Classify an image
@@ -62,26 +70,37 @@ typedef struct TfLiteImageClassifierOptions {
6270
TfLiteBaseOptions base_options;
6371
} TfLiteImageClassifierOptions;
6472

73+
// Creates and returns TfLiteImageClassifierOptions initialized with default
74+
// values. Default values are as follows:
75+
// 1. .classification_options.max_results = -1, which returns all classification
76+
// categories by default.
77+
// 2. .base_options.compute_settings.tflite_settings.cpu_settings.num_threads =
78+
// -1, which makes the TFLite runtime choose the value.
79+
// 3. .classification_options.score_threshold = 0
80+
// 4. All pointers like .base_options.model_file.file_path,
81+
// .base_options.classification_options.display_names_local,
82+
// .classification_options.label_allowlist.list,
83+
// options.classification_options.label_denylist.list are NULL.
84+
// 5. All other integer values are initialized to 0.
85+
TfLiteImageClassifierOptions TfLiteImageClassifierOptionsCreate();
86+
6587
// Creates TfLiteImageClassifier from options.
66-
// base_options.model_file.file_path in TfLiteImageClassifierOptions should be
88+
// .base_options.model_file.file_path in TfLiteImageClassifierOptions should be
6789
// set to the path of the tflite model you wish to create the
6890
// TfLiteImageClassifier with.
6991
// Returns nullptr under the following circumstances:
7092
// 1. file doesn't exist or is not a well formatted.
7193
// 2. options is nullptr.
72-
// 3. Both options.classification_options.class_name_blacklist and
73-
// options.classification_options.class_name_blacklist are non empty. These
94+
// 3. Both options.classification_options.label_denylist and
95+
// options.classification_options.label_allowlist are non empty. These
7496
// fields are mutually exclusive.
7597
//
76-
// If
77-
// options->base_options.compute_settings.tflite_settings.cpu_settings.num_threads
78-
// <= 0, it will be set to a default of -1 which indicates the TFLite runtime to
79-
// choose the value.
80-
//
81-
// TfLiteImageClassifierOptions must be zero initialized to avoid seg faults.
82-
//
83-
// TODO(prianka): create default TfLiteImageClassifierOptions with default
84-
// values.
98+
// Create TfLiteImageClassifierOptions using
99+
// TfLiteImageClassifierOptionsCreate(). If need be, you can change the default
100+
// values of options for customizing classification, If options are not created
101+
// in the aforementioned way, you have to make sure that all members are
102+
// initialized to respective default values and all pointer members are zero
103+
// initialized to avoid any undefined behaviour.
85104
TfLiteImageClassifier* TfLiteImageClassifierFromOptions(
86105
const TfLiteImageClassifierOptions* options);
87106

tensorflow_lite_support/c/test/task/vision/image_classifier_test.cc

Lines changed: 26 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ StatusOr<ImageData> LoadImage(const char* image_name) {
4848
class ImageClassifierFromOptionsTest : public tflite_shims::testing::Test {};
4949

5050
TEST_F(ImageClassifierFromOptionsTest, FailsWithMissingModelPath) {
51-
TfLiteImageClassifierOptions options = {{{0}}};
51+
TfLiteImageClassifierOptions options = TfLiteImageClassifierOptionsCreate();
5252
TfLiteImageClassifier* image_classifier =
5353
TfLiteImageClassifierFromOptions(&options);
5454
EXPECT_EQ(image_classifier, nullptr);
@@ -58,7 +58,7 @@ TEST_F(ImageClassifierFromOptionsTest, SucceedsWithModelPath) {
5858
std::string model_path =
5959
JoinPath("./" /*test src dir*/, kTestDataDirectory,
6060
kMobileNetQuantizedWithMetadata);
61-
TfLiteImageClassifierOptions options = {{{0}}};
61+
TfLiteImageClassifierOptions options = TfLiteImageClassifierOptionsCreate();
6262
options.base_options.model_file.file_path = model_path.data();
6363
TfLiteImageClassifier* image_classifier =
6464
TfLiteImageClassifierFromOptions(&options);
@@ -71,7 +71,7 @@ TEST_F(ImageClassifierFromOptionsTest, SucceedsWithNumberOfThreads) {
7171
std::string model_path =
7272
JoinPath("./" /*test src dir*/, kTestDataDirectory,
7373
kMobileNetQuantizedWithMetadata);
74-
TfLiteImageClassifierOptions options = {{{0}}};
74+
TfLiteImageClassifierOptions options = TfLiteImageClassifierOptionsCreate();
7575
options.base_options.model_file.file_path = model_path.data();
7676
options.base_options.compute_settings.cpu_settings.num_threads = 3;
7777
TfLiteImageClassifier* image_classifier =
@@ -82,23 +82,21 @@ TEST_F(ImageClassifierFromOptionsTest, SucceedsWithNumberOfThreads) {
8282
}
8383

8484
TEST_F(ImageClassifierFromOptionsTest,
85-
FailsWithClassNameBlackListAndClassNameWhiteList) {
85+
FailsWithClassNameDenyListAndClassNameAllowList) {
8686
std::string model_path =
8787
JoinPath("./" /*test src dir*/, kTestDataDirectory,
8888
kMobileNetQuantizedWithMetadata);
8989

90-
TfLiteImageClassifierOptions options = {{{0}}};
90+
TfLiteImageClassifierOptions options = TfLiteImageClassifierOptionsCreate();
9191
options.base_options.model_file.file_path = model_path.data();
9292

93-
const char* class_name_blacklist[] = {"brambling"};
94-
options.classification_options.class_name_blacklist.list =
95-
class_name_blacklist;
96-
options.classification_options.class_name_blacklist.length = 1;
93+
const char* label_denylist[] = {"brambling"};
94+
options.classification_options.label_denylist.list = label_denylist;
95+
options.classification_options.label_denylist.length = 1;
9796

98-
const char* class_name_whitelist[] = {"cheeseburger"};
99-
options.classification_options.class_name_whitelist.list =
100-
class_name_whitelist;
101-
options.classification_options.class_name_whitelist.length = 1;
97+
const char* label_allowlist[] = {"cheeseburger"};
98+
options.classification_options.label_allowlist.list = label_allowlist;
99+
options.classification_options.label_allowlist.length = 1;
102100

103101
TfLiteImageClassifier* image_classifier =
104102
TfLiteImageClassifierFromOptions(&options);
@@ -114,7 +112,7 @@ class ImageClassifierClassifyTest : public tflite_shims::testing::Test {
114112
JoinPath("./" /*test src dir*/, kTestDataDirectory,
115113
kMobileNetQuantizedWithMetadata);
116114

117-
TfLiteImageClassifierOptions options = {{{0}}};
115+
TfLiteImageClassifierOptions options = TfLiteImageClassifierOptionsCreate();
118116
options.base_options.model_file.file_path = model_path.data();
119117
image_classifier = TfLiteImageClassifierFromOptions(&options);
120118

@@ -197,19 +195,18 @@ TEST_F(ImageClassifierClassifyTest, FailsWithRoiOutsideImageBounds) {
197195
}
198196

199197
TEST(ImageClassifierWithUserDefinedOptionsClassifyTest,
200-
SucceedsWithClassNameBlackList) {
201-
const char* blacklisted_label_name = "cheeseburger";
198+
SucceedsWithClassNameDenyList) {
199+
const char* denylisted_label_name = "cheeseburger";
202200
std::string model_path =
203201
JoinPath("./" /*test src dir*/, kTestDataDirectory,
204202
kMobileNetQuantizedWithMetadata);
205203

206-
TfLiteImageClassifierOptions options = {{{0}}};
204+
TfLiteImageClassifierOptions options = TfLiteImageClassifierOptionsCreate();
207205
options.base_options.model_file.file_path = model_path.data();
208206

209-
const char* class_name_blacklist[] = {blacklisted_label_name};
210-
options.classification_options.class_name_blacklist.list =
211-
class_name_blacklist;
212-
options.classification_options.class_name_blacklist.length = 1;
207+
const char* label_denylist[] = {denylisted_label_name};
208+
options.classification_options.label_denylist.list = label_denylist;
209+
options.classification_options.label_denylist.length = 1;
213210

214211
TfLiteImageClassifier* image_classifier =
215212
TfLiteImageClassifierFromOptions(&options);
@@ -233,7 +230,7 @@ TEST(ImageClassifierWithUserDefinedOptionsClassifyTest,
233230
EXPECT_GE(classification_result->classifications->size, 1);
234231
EXPECT_NE(classification_result->classifications->categories, nullptr);
235232
EXPECT_NE(strcmp(classification_result->classifications->categories[0].label,
236-
blacklisted_label_name),
233+
denylisted_label_name),
237234
0);
238235

239236
if (image_classifier) TfLiteImageClassifierDelete(image_classifier);
@@ -242,20 +239,19 @@ TEST(ImageClassifierWithUserDefinedOptionsClassifyTest,
242239
}
243240

244241
TEST(ImageClassifierWithUserDefinedOptionsClassifyTest,
245-
SucceedsWithClassNameWhiteList) {
246-
const char* whitelisted_label_name = "cheeseburger";
242+
SucceedsWithClassNameAllowList) {
243+
const char* allowlisted_label_name = "cheeseburger";
247244
std::string model_path =
248245
JoinPath("./" /*test src dir*/, kTestDataDirectory,
249246
kMobileNetQuantizedWithMetadata)
250247
.data();
251248

252-
TfLiteImageClassifierOptions options = {{{0}}};
249+
TfLiteImageClassifierOptions options = TfLiteImageClassifierOptionsCreate();
253250
options.base_options.model_file.file_path = model_path.data();
254251

255-
const char* class_name_whitelist[] = {whitelisted_label_name};
256-
options.classification_options.class_name_whitelist.list =
257-
class_name_whitelist;
258-
options.classification_options.class_name_whitelist.length = 1;
252+
const char* label_allowlist[] = {allowlisted_label_name};
253+
options.classification_options.label_allowlist.list = label_allowlist;
254+
options.classification_options.label_allowlist.length = 1;
259255

260256
TfLiteImageClassifier* image_classifier =
261257
TfLiteImageClassifierFromOptions(&options);
@@ -279,7 +275,7 @@ TEST(ImageClassifierWithUserDefinedOptionsClassifyTest,
279275
EXPECT_GE(classification_result->classifications->size, 1);
280276
EXPECT_NE(classification_result->classifications->categories, nullptr);
281277
EXPECT_EQ(strcmp(classification_result->classifications->categories[0].label,
282-
whitelisted_label_name),
278+
allowlisted_label_name),
283279
0);
284280

285281
if (image_classifier) TfLiteImageClassifierDelete(image_classifier);

0 commit comments

Comments
 (0)