Skip to content

Commit a278686

Browse files
lu-wang-gtflite-support-robot
authored and committed
Migrate NLClassifier to using TextPreprocessor
PiperOrigin-RevId: 410426912
1 parent ded908c commit a278686

File tree

4 files changed

+34
-197
lines changed

4 files changed

+34
-197
lines changed

tensorflow_lite_support/cc/task/processor/text_preprocessor.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -214,8 +214,7 @@ absl::Status TextPreprocessor::BertPreprocess(const std::string& input_text) {
214214

215215
absl::Status TextPreprocessor::RegexPreprocess(const std::string& input_text) {
216216
TfLiteTensor* input_tensor = GetTensor();
217-
auto regex_tokenizer = std::unique_ptr<RegexTokenizer>(
218-
dynamic_cast<RegexTokenizer*>(tokenizer_.release()));
217+
auto regex_tokenizer = dynamic_cast<RegexTokenizer*>(tokenizer_.get());
219218

220219
// |<-------sentence_length-------->|
221220
// input_tensor <START>, t1, t2... <PAD>, <PAD>...

tensorflow_lite_support/cc/task/text/nlclassifier/BUILD

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ cc_library_with_tflite(
2020
"@org_tensorflow//tensorflow/lite/core/shims:builtin_ops",
2121
"//tensorflow_lite_support/cc/task/core:base_task_api",
2222
"//tensorflow_lite_support/cc/task/core:task_api_factory",
23+
"//tensorflow_lite_support/cc/task/processor:text_preprocessor",
2324
],
2425
deps = [
2526
"//tensorflow_lite_support/cc:common",
@@ -28,10 +29,7 @@ cc_library_with_tflite(
2829
"//tensorflow_lite_support/cc/task/core:category",
2930
"//tensorflow_lite_support/cc/task/core:task_utils",
3031
"//tensorflow_lite_support/cc/task/text/proto:nl_classifier_options_proto_inc",
31-
"//tensorflow_lite_support/cc/text/tokenizers:regex_tokenizer",
32-
"//tensorflow_lite_support/cc/text/tokenizers:tokenizer",
3332
"//tensorflow_lite_support/cc/utils:common_utils",
34-
"//tensorflow_lite_support/metadata/cc:metadata_extractor",
3533
"@com_google_absl//absl/algorithm:container",
3634
"@com_google_absl//absl/base:core_headers",
3735
"@com_google_absl//absl/status",

tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc

Lines changed: 9 additions & 184 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,6 @@ limitations under the License.
3535
#include "tensorflow_lite_support/cc/task/core/category.h"
3636
#include "tensorflow_lite_support/cc/task/core/task_api_factory.h"
3737
#include "tensorflow_lite_support/cc/task/core/task_utils.h"
38-
#include "tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.h"
39-
#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer.h"
4038
#include "tensorflow_lite_support/cc/utils/common_utils.h"
4139

4240
namespace tflite {
@@ -51,22 +49,16 @@ using ::tflite::TensorMetadata;
5149
using ::tflite::support::CreateStatusWithPayload;
5250
using ::tflite::support::StatusOr;
5351
using ::tflite::support::TfLiteSupportStatus;
54-
using ::tflite::support::text::tokenizer::RegexTokenizer;
55-
using ::tflite::support::text::tokenizer::Tokenizer;
56-
using ::tflite::support::text::tokenizer::TokenizerResult;
5752
using ::tflite::support::utils::LoadVocabFromBuffer;
5853
using ::tflite::task::core::Category;
5954
using ::tflite::task::core::Dequantize;
6055
using ::tflite::task::core::GetStringAtIndex;
61-
using ::tflite::task::core::PopulateTensor;
6256
using ::tflite::task::core::TaskAPIFactory;
6357
// To differentiate it from the struct option,
6458
// tflite::task::text::nl_classifier::NLClassifierOptions.
6559
using NLClassifierProtoOptions = ::tflite::task::text::NLClassifierOptions;
6660

6761
namespace {
68-
constexpr int kRegexTokenizerInputTensorIndex = 0;
69-
constexpr int kRegexTokenizerProcessUnitIndex = 0;
7062

7163
absl::Status SanityCheckOptions(const NLClassifierProtoOptions& options) {
7264
if (!options.has_base_options()) {
@@ -77,78 +69,6 @@ absl::Status SanityCheckOptions(const NLClassifierProtoOptions& options) {
7769
return absl::OkStatus();
7870
}
7971

80-
StatusOr<absl::string_view> CheckAndLoadFirstAssociatedFile(
81-
const flatbuffers::Vector<flatbuffers::Offset<tflite::AssociatedFile>>*
82-
associated_files,
83-
const tflite::metadata::ModelMetadataExtractor* metadata_extractor) {
84-
if (associated_files == nullptr || associated_files->size() < 1 ||
85-
associated_files->Get(0)->name() == nullptr) {
86-
return CreateStatusWithPayload(
87-
absl::StatusCode::kInvalidArgument,
88-
"Invalid vocab_file from input process unit.",
89-
TfLiteSupportStatus::kMetadataInvalidTokenizerError);
90-
}
91-
ASSIGN_OR_RETURN(absl::string_view vocab_buffer,
92-
metadata_extractor->GetAssociatedFile(
93-
associated_files->Get(0)->name()->str()));
94-
return vocab_buffer;
95-
}
96-
97-
StatusOr<std::unique_ptr<Tokenizer>> CreateRegexTokenizerFromProcessUnit(
98-
const tflite::ProcessUnit* tokenizer_process_unit,
99-
const tflite::metadata::ModelMetadataExtractor* metadata_extractor) {
100-
if (metadata_extractor == nullptr || tokenizer_process_unit == nullptr) {
101-
return CreateStatusWithPayload(
102-
absl::StatusCode::kInvalidArgument,
103-
"No metadata or input process unit found.",
104-
TfLiteSupportStatus::kMetadataInvalidTokenizerError);
105-
}
106-
107-
if (tokenizer_process_unit->options_type() !=
108-
ProcessUnitOptions_RegexTokenizerOptions) {
109-
return CreateStatusWithPayload(
110-
absl::StatusCode::kNotFound,
111-
absl::StrCat(
112-
"Incorrect options_type:", tokenizer_process_unit->options_type(),
113-
" need RegexTokenizerOptions."),
114-
TfLiteSupportStatus::kMetadataInvalidTokenizerError);
115-
}
116-
117-
const tflite::RegexTokenizerOptions* options =
118-
tokenizer_process_unit->options_as<RegexTokenizerOptions>();
119-
ASSIGN_OR_RETURN(absl::string_view vocab_buffer,
120-
CheckAndLoadFirstAssociatedFile(options->vocab_file(),
121-
metadata_extractor));
122-
if (options->delim_regex_pattern() == nullptr) {
123-
return CreateStatusWithPayload(
124-
absl::StatusCode::kInvalidArgument,
125-
"Invalid delim_regex_pattern from input process unit.",
126-
TfLiteSupportStatus::kMetadataInvalidTokenizerError);
127-
}
128-
129-
std::unique_ptr<RegexTokenizer> regex_tokenizer =
130-
absl::make_unique<RegexTokenizer>(options->delim_regex_pattern()->str(),
131-
vocab_buffer.data(),
132-
vocab_buffer.size());
133-
134-
int unknown_token_id = 0;
135-
if (!regex_tokenizer->GetUnknownToken(&unknown_token_id)) {
136-
return CreateStatusWithPayload(
137-
absl::StatusCode::kInvalidArgument,
138-
"RegexTokenizer doesn't have <UNKNOWN> token.",
139-
TfLiteSupportStatus::kMetadataInvalidTokenizerError);
140-
}
141-
142-
int pad_token_id = 0;
143-
if (!regex_tokenizer->GetPadToken(&pad_token_id)) {
144-
return CreateStatusWithPayload(
145-
absl::StatusCode::kInvalidArgument,
146-
"RegexTokenizer doesn't have <PAD> token.",
147-
TfLiteSupportStatus::kMetadataInvalidTokenizerError);
148-
}
149-
return std::move(regex_tokenizer);
150-
}
151-
15272
} // namespace
15373

15474
const NLClassifierOptions& NLClassifier::GetOptions() const {
@@ -201,58 +121,7 @@ std::vector<Category> NLClassifier::Classify(const std::string& text) {
201121

202122
absl::Status NLClassifier::Preprocess(
203123
const std::vector<TfLiteTensor*>& input_tensors, const std::string& input) {
204-
TfLiteTensor* input_tensor = FindTensorWithNameOrIndex(
205-
input_tensors, GetMetadataExtractor()->GetInputTensorMetadata(),
206-
struct_options_.input_tensor_name, struct_options_.input_tensor_index);
207-
if (input_tensor == nullptr) {
208-
return CreateStatusWithPayload(
209-
absl::StatusCode::kInvalidArgument,
210-
"No input tensor found from NLClassifierOptions.",
211-
TfLiteSupportStatus::kInputTensorNotFoundError);
212-
}
213-
214-
if (HasRegexTokenizerMetadata()) {
215-
// |<-------sentence_length-------->|
216-
// input_tensor <START>, t1, t2... <PAD>, <PAD>...
217-
// <START> is optional, t1, t2... will be replaced by <UNKNOWN> if it's not
218-
// found in tokenizer vocab.
219-
TokenizerResult result = tokenizer_->Tokenize(input);
220-
221-
size_t max_sentence_length = input_tensor->dims->size == 2
222-
? input_tensor->dims->data[1]
223-
: input_tensor->dims->data[0];
224-
225-
int unknown_token_id = 0;
226-
tokenizer_->GetUnknownToken(&unknown_token_id);
227-
228-
int pad_token_id = 0;
229-
tokenizer_->GetPadToken(&pad_token_id);
230-
231-
std::vector<int> input_tokens(max_sentence_length, pad_token_id);
232-
int start_token_id = 0;
233-
size_t input_token_index = 0;
234-
if (tokenizer_->GetStartToken(&start_token_id)) {
235-
input_tokens[0] = start_token_id;
236-
input_token_index = 1;
237-
}
238-
239-
for (size_t i = 0; (i < result.subwords.size()) &&
240-
(input_token_index < max_sentence_length);
241-
++i, ++input_token_index) {
242-
const std::string& token = result.subwords[i];
243-
int token_id = 0;
244-
if (tokenizer_->LookupId(token, &token_id)) {
245-
input_tokens[input_token_index] = token_id;
246-
} else {
247-
input_tokens[input_token_index] = unknown_token_id;
248-
}
249-
}
250-
251-
RETURN_IF_ERROR(PopulateTensor(input_tokens, input_tensor));
252-
} else {
253-
RETURN_IF_ERROR(PopulateTensor(input, input_tensor));
254-
}
255-
return absl::OkStatus();
124+
return preprocessor_->Preprocess(input);
256125
}
257126

258127
StatusOr<std::vector<Category>> NLClassifier::Postprocess(
@@ -327,38 +196,23 @@ absl::Status NLClassifier::Initialize(
327196

328197
absl::Status NLClassifier::Initialize(const NLClassifierOptions& options) {
329198
struct_options_ = options;
330-
// input tensor should be type STRING
331-
auto input_tensor = FindTensorWithNameOrIndex(
199+
200+
int input_index = FindTensorIndex(
332201
GetInputTensors(), GetMetadataExtractor()->GetInputTensorMetadata(),
333202
options.input_tensor_name, options.input_tensor_index);
334-
if (input_tensor == nullptr) {
203+
204+
if (input_index < 0 || input_index >= GetInputCount()) {
335205
return CreateStatusWithPayload(
336206
StatusCode::kInvalidArgument,
337207
absl::StrCat("No input tensor found with name ",
338208
options.input_tensor_name, " or at index ",
339209
options.input_tensor_index),
340210
TfLiteSupportStatus::kInputTensorNotFoundError);
341211
}
342-
if (HasRegexTokenizerMetadata()) {
343-
if (input_tensor->type != kTfLiteInt32) {
344-
return CreateStatusWithPayload(
345-
StatusCode::kInvalidArgument,
346-
absl::StrCat("Type mismatch for input tensor ", input_tensor->name,
347-
". Requested INT32, got ",
348-
TfLiteTypeGetName(input_tensor->type), "."),
349-
TfLiteSupportStatus::kInvalidInputTensorTypeError);
350-
}
351-
RETURN_IF_ERROR(SetupRegexTokenizer());
352-
} else {
353-
if (input_tensor->type != kTfLiteString) {
354-
return CreateStatusWithPayload(
355-
StatusCode::kInvalidArgument,
356-
absl::StrCat("Type mismatch for input tensor ", input_tensor->name,
357-
". Requested STRING, got ",
358-
TfLiteTypeGetName(input_tensor->type), "."),
359-
TfLiteSupportStatus::kInvalidInputTensorTypeError);
360-
}
361-
}
212+
213+
// Create preprocessor.
214+
ASSIGN_OR_RETURN(preprocessor_, processor::TextPreprocessor::Create(
215+
GetTfLiteEngine(), {input_index}));
362216

363217
// output score tensor should be type
364218
// UINT8/INT8/INT16(quantized) or FLOAT32/FLOAT64(dequantized) or BOOL
@@ -480,35 +334,6 @@ StatusOr<std::unique_ptr<NLClassifier>> NLClassifier::CreateFromFdAndOptions(
480334
return std::move(nl_classifier);
481335
}
482336

483-
bool NLClassifier::HasRegexTokenizerMetadata() {
484-
const TensorMetadata* input_tensor_metadata =
485-
GetMetadataExtractor()->GetInputTensorMetadata(
486-
kRegexTokenizerInputTensorIndex);
487-
if (input_tensor_metadata == nullptr) {
488-
return false;
489-
}
490-
tflite::support::StatusOr<const tflite::ProcessUnit*> status =
491-
GetMetadataExtractor()->FindFirstProcessUnit(
492-
*input_tensor_metadata, ProcessUnitOptions_RegexTokenizerOptions);
493-
return status.ok() ? status.value() != nullptr : false;
494-
}
495-
496-
absl::Status NLClassifier::SetupRegexTokenizer() {
497-
ASSIGN_OR_RETURN(
498-
std::unique_ptr<Tokenizer> base_tokenizer,
499-
CreateRegexTokenizerFromProcessUnit(
500-
GetMetadataExtractor()
501-
->GetInputTensorMetadata(kRegexTokenizerInputTensorIndex)
502-
->process_units()
503-
->Get(kRegexTokenizerProcessUnitIndex),
504-
GetMetadataExtractor()));
505-
506-
tokenizer_ = std::unique_ptr<RegexTokenizer>(
507-
dynamic_cast<RegexTokenizer*>(base_tokenizer.release()));
508-
509-
return absl::OkStatus();
510-
}
511-
512337
} // namespace nlclassifier
513338
} // namespace text
514339
} // namespace task

tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ limitations under the License.
3434
#include "tensorflow_lite_support/cc/port/statusor.h"
3535
#include "tensorflow_lite_support/cc/task/core/base_task_api.h"
3636
#include "tensorflow_lite_support/cc/task/core/category.h"
37+
#include "tensorflow_lite_support/cc/task/processor/text_preprocessor.h"
3738
#include "tensorflow_lite_support/cc/task/text/proto/nl_classifier_options_proto_inc.h"
38-
#include "tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.h"
3939

4040
namespace tflite {
4141
namespace task {
@@ -181,25 +181,41 @@ class NLClassifier : public core::BaseTaskApi<std::vector<core::Category>,
181181
const flatbuffers::Vector<flatbuffers::Offset<TensorMetadata>>*
182182
metadata_array,
183183
const std::string& name, int index) {
184+
int tensor_index = FindTensorIndex(tensors, metadata_array, name, index);
185+
return tensor_index >= 0 && tensor_index < tensors.size()
186+
? tensors[tensor_index]
187+
: nullptr;
188+
}
189+
190+
// Gets the tensor index of the specified tensor name from a vector of tensors.
191+
// Returns default_index if no tensor is found by name (metadata tensor name or
192+
// model tensor name).
193+
template <typename TensorType>
194+
static int FindTensorIndex(
195+
const std::vector<TensorType*>& tensors,
196+
const flatbuffers::Vector<flatbuffers::Offset<TensorMetadata>>*
197+
metadata_array,
198+
const std::string& name, int default_index) {
184199
if (metadata_array != nullptr && metadata_array->size() == tensors.size()) {
185200
for (int i = 0; i < metadata_array->size(); i++) {
186201
if (strcmp(name.data(), metadata_array->Get(i)->name()->c_str()) == 0) {
187-
return tensors[i];
202+
return i;
188203
}
189204
}
190205
}
191206

192-
for (TensorType* tensor : tensors) {
207+
for (int i = 0; i < tensors.size(); i++) {
208+
TensorType* tensor = tensors[i];
193209
if (tensor->name == name) {
194-
return tensor;
210+
return i;
195211
}
196212
}
197-
return index >= 0 && index < tensors.size() ? tensors[index] : nullptr;
213+
return default_index;
198214
}
199215

200216
private:
201-
bool HasRegexTokenizerMetadata();
202-
absl::Status SetupRegexTokenizer();
217+
std::unique_ptr<tflite::task::processor::TextPreprocessor> preprocessor_ =
218+
nullptr;
203219

204220
std::unique_ptr<tflite::task::text::NLClassifierOptions> proto_options_;
205221

@@ -210,7 +226,6 @@ class NLClassifier : public core::BaseTaskApi<std::vector<core::Category>,
210226
// labels vector initialized from output tensor's associated file, if one
211227
// exists.
212228
std::unique_ptr<std::vector<std::string>> labels_vector_;
213-
std::unique_ptr<tflite::support::text::tokenizer::RegexTokenizer> tokenizer_;
214229
};
215230

216231
} // namespace nlclassifier

0 commit comments

Comments
 (0)