
Commit aed9982

ziyeqinghan authored and tflite-support-robot committed
Migrate NLClassifier to using TextPreprocessor
PiperOrigin-RevId: 410445262
1 parent a278686 commit aed9982

File tree: 4 files changed, +197 -34 lines


tensorflow_lite_support/cc/task/processor/text_preprocessor.cc

Lines changed: 2 additions & 1 deletion
@@ -214,7 +214,8 @@ absl::Status TextPreprocessor::BertPreprocess(const std::string& input_text) {

 absl::Status TextPreprocessor::RegexPreprocess(const std::string& input_text) {
   TfLiteTensor* input_tensor = GetTensor();
-  auto regex_tokenizer = dynamic_cast<RegexTokenizer*>(tokenizer_.get());
+  auto regex_tokenizer = std::unique_ptr<RegexTokenizer>(
+      dynamic_cast<RegexTokenizer*>(tokenizer_.release()));

   // |<-------sentence_length-------->|
   // input_tensor <START>, t1, t2... <PAD>, <PAD>...
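Not part of the commit: the replacement above moves ownership of the tokenizer into a typed unique_ptr by releasing the base pointer before the downcast, instead of aliasing it through get(). A minimal standalone sketch of that release-then-dynamic_cast pattern, using hypothetical Tokenizer/RegexTok types:

#include <memory>

struct Tokenizer { virtual ~Tokenizer() = default; };         // hypothetical base
struct RegexTok : Tokenizer { /* regex-specific members */ };  // hypothetical derived

// Move ownership out of the base-class pointer, then downcast the raw pointer.
// If the object is not actually a RegexTok, dynamic_cast yields nullptr and the
// released object would leak, so this is only safe when the dynamic type is
// known in advance, as it is for the regex tokenizer here.
std::unique_ptr<RegexTok> TakeAsRegexTok(std::unique_ptr<Tokenizer> base) {
  return std::unique_ptr<RegexTok>(dynamic_cast<RegexTok*>(base.release()));
}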

tensorflow_lite_support/cc/task/text/nlclassifier/BUILD

Lines changed: 3 additions & 1 deletion
@@ -20,7 +20,6 @@ cc_library_with_tflite(
        "@org_tensorflow//tensorflow/lite/core/shims:builtin_ops",
        "//tensorflow_lite_support/cc/task/core:base_task_api",
        "//tensorflow_lite_support/cc/task/core:task_api_factory",
-        "//tensorflow_lite_support/cc/task/processor:text_preprocessor",
    ],
    deps = [
        "//tensorflow_lite_support/cc:common",
@@ -29,7 +28,10 @@ cc_library_with_tflite(
        "//tensorflow_lite_support/cc/task/core:category",
        "//tensorflow_lite_support/cc/task/core:task_utils",
        "//tensorflow_lite_support/cc/task/text/proto:nl_classifier_options_proto_inc",
+        "//tensorflow_lite_support/cc/text/tokenizers:regex_tokenizer",
+        "//tensorflow_lite_support/cc/text/tokenizers:tokenizer",
        "//tensorflow_lite_support/cc/utils:common_utils",
+        "//tensorflow_lite_support/metadata/cc:metadata_extractor",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/status",

tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.cc

Lines changed: 184 additions & 9 deletions
@@ -35,6 +35,8 @@ limitations under the License.
 #include "tensorflow_lite_support/cc/task/core/category.h"
 #include "tensorflow_lite_support/cc/task/core/task_api_factory.h"
 #include "tensorflow_lite_support/cc/task/core/task_utils.h"
+#include "tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.h"
+#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer.h"
 #include "tensorflow_lite_support/cc/utils/common_utils.h"

 namespace tflite {
@@ -49,16 +51,22 @@ using ::tflite::TensorMetadata;
 using ::tflite::support::CreateStatusWithPayload;
 using ::tflite::support::StatusOr;
 using ::tflite::support::TfLiteSupportStatus;
+using ::tflite::support::text::tokenizer::RegexTokenizer;
+using ::tflite::support::text::tokenizer::Tokenizer;
+using ::tflite::support::text::tokenizer::TokenizerResult;
 using ::tflite::support::utils::LoadVocabFromBuffer;
 using ::tflite::task::core::Category;
 using ::tflite::task::core::Dequantize;
 using ::tflite::task::core::GetStringAtIndex;
+using ::tflite::task::core::PopulateTensor;
 using ::tflite::task::core::TaskAPIFactory;
 // To differenciate it with the struct option,
 // tflite::task::text::nl_classifier::NLClassifierOptions.
 using NLClassifierProtoOptions = ::tflite::task::text::NLClassifierOptions;

 namespace {
+constexpr int kRegexTokenizerInputTensorIndex = 0;
+constexpr int kRegexTokenizerProcessUnitIndex = 0;

 absl::Status SanityCheckOptions(const NLClassifierProtoOptions& options) {
   if (!options.has_base_options()) {
@@ -69,6 +77,78 @@ absl::Status SanityCheckOptions(const NLClassifierProtoOptions& options) {
   return absl::OkStatus();
 }

+StatusOr<absl::string_view> CheckAndLoadFirstAssociatedFile(
+    const flatbuffers::Vector<flatbuffers::Offset<tflite::AssociatedFile>>*
+        associated_files,
+    const tflite::metadata::ModelMetadataExtractor* metadata_extractor) {
+  if (associated_files == nullptr || associated_files->size() < 1 ||
+      associated_files->Get(0)->name() == nullptr) {
+    return CreateStatusWithPayload(
+        absl::StatusCode::kInvalidArgument,
+        "Invalid vocab_file from input process unit.",
+        TfLiteSupportStatus::kMetadataInvalidTokenizerError);
+  }
+  ASSIGN_OR_RETURN(absl::string_view vocab_buffer,
+                   metadata_extractor->GetAssociatedFile(
+                       associated_files->Get(0)->name()->str()));
+  return vocab_buffer;
+}
+
+StatusOr<std::unique_ptr<Tokenizer>> CreateRegexTokenizerFromProcessUnit(
+    const tflite::ProcessUnit* tokenizer_process_unit,
+    const tflite::metadata::ModelMetadataExtractor* metadata_extractor) {
+  if (metadata_extractor == nullptr || tokenizer_process_unit == nullptr) {
+    return CreateStatusWithPayload(
+        absl::StatusCode::kInvalidArgument,
+        "No metadata or input process unit found.",
+        TfLiteSupportStatus::kMetadataInvalidTokenizerError);
+  }
+
+  if (tokenizer_process_unit->options_type() !=
+      ProcessUnitOptions_RegexTokenizerOptions) {
+    return CreateStatusWithPayload(
+        absl::StatusCode::kNotFound,
+        absl::StrCat(
+            "Incorrect options_type:", tokenizer_process_unit->options_type(),
+            " need RegexTokenizerOptions."),
+        TfLiteSupportStatus::kMetadataInvalidTokenizerError);
+  }
+
+  const tflite::RegexTokenizerOptions* options =
+      tokenizer_process_unit->options_as<RegexTokenizerOptions>();
+  ASSIGN_OR_RETURN(absl::string_view vocab_buffer,
+                   CheckAndLoadFirstAssociatedFile(options->vocab_file(),
+                                                   metadata_extractor));
+  if (options->delim_regex_pattern() == nullptr) {
+    return CreateStatusWithPayload(
+        absl::StatusCode::kInvalidArgument,
+        "Invalid delim_regex_pattern from input process unit.",
+        TfLiteSupportStatus::kMetadataInvalidTokenizerError);
+  }
+
+  std::unique_ptr<RegexTokenizer> regex_tokenizer =
+      absl::make_unique<RegexTokenizer>(options->delim_regex_pattern()->str(),
+                                        vocab_buffer.data(),
+                                        vocab_buffer.size());
+
+  int unknown_token_id = 0;
+  if (!regex_tokenizer->GetUnknownToken(&unknown_token_id)) {
+    return CreateStatusWithPayload(
+        absl::StatusCode::kInvalidArgument,
+        "RegexTokenizer doesn't have <UNKNOWN> token.",
+        TfLiteSupportStatus::kMetadataInvalidTokenizerError);
+  }
+
+  int pad_token_id = 0;
+  if (!regex_tokenizer->GetPadToken(&pad_token_id)) {
+    return CreateStatusWithPayload(
+        absl::StatusCode::kInvalidArgument,
+        "RegexTokenizer doesn't have <PAD> token.",
+        TfLiteSupportStatus::kMetadataInvalidTokenizerError);
+  }
+  return std::move(regex_tokenizer);
+}
+
 } // namespace

 const NLClassifierOptions& NLClassifier::GetOptions() const {
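For context, the helpers added above chain metadata lookups with ASSIGN_OR_RETURN over StatusOr values, so any failing step short-circuits with its error. A rough, hand-expanded equivalent using plain absl::StatusOr, with hypothetical LoadVocab/CountVocabEntries names that are not part of this commit:

#include <algorithm>
#include <string>

#include "absl/status/status.h"
#include "absl/status/statusor.h"

// Hypothetical stand-in for the associated-file lookup done by
// CheckAndLoadFirstAssociatedFile above.
absl::StatusOr<std::string> LoadVocab(bool vocab_present) {
  if (!vocab_present) {
    return absl::InvalidArgumentError("Invalid vocab_file from input process unit.");
  }
  return std::string("<PAD>\n<START>\n<UNKNOWN>\n");
}

// Hand-expanded equivalent of ASSIGN_OR_RETURN: propagate the error,
// otherwise keep the value and continue.
absl::StatusOr<int> CountVocabEntries(bool vocab_present) {
  absl::StatusOr<std::string> vocab = LoadVocab(vocab_present);
  if (!vocab.ok()) return vocab.status();
  return static_cast<int>(std::count(vocab->begin(), vocab->end(), '\n'));
}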
@@ -121,7 +201,58 @@ std::vector<Category> NLClassifier::Classify(const std::string& text) {

 absl::Status NLClassifier::Preprocess(
     const std::vector<TfLiteTensor*>& input_tensors, const std::string& input) {
-  return preprocessor_->Preprocess(input);
+  TfLiteTensor* input_tensor = FindTensorWithNameOrIndex(
+      input_tensors, GetMetadataExtractor()->GetInputTensorMetadata(),
+      struct_options_.input_tensor_name, struct_options_.input_tensor_index);
+  if (input_tensor == nullptr) {
+    return CreateStatusWithPayload(
+        absl::StatusCode::kInvalidArgument,
+        "No input tensor found from NLClassifierOptions.",
+        TfLiteSupportStatus::kInputTensorNotFoundError);
+  }
+
+  if (HasRegexTokenizerMetadata()) {
+    //              |<-------sentence_length-------->|
+    // input_tensor <START>, t1, t2... <PAD>, <PAD>...
+    // <START> is optional, t1, t2... will be replaced by <UNKNOWN> if it's not
+    // found in tokenizer vocab.
+    TokenizerResult result = tokenizer_->Tokenize(input);
+
+    size_t max_sentence_length = input_tensor->dims->size == 2
+                                     ? input_tensor->dims->data[1]
+                                     : input_tensor->dims->data[0];
+
+    int unknown_token_id = 0;
+    tokenizer_->GetUnknownToken(&unknown_token_id);
+
+    int pad_token_id = 0;
+    tokenizer_->GetPadToken(&pad_token_id);
+
+    std::vector<int> input_tokens(max_sentence_length, pad_token_id);
+    int start_token_id = 0;
+    size_t input_token_index = 0;
+    if (tokenizer_->GetStartToken(&start_token_id)) {
+      input_tokens[0] = start_token_id;
+      input_token_index = 1;
+    }
+
+    for (size_t i = 0; (i < result.subwords.size()) &&
+                       (input_token_index < max_sentence_length);
+         ++i, ++input_token_index) {
+      const std::string& token = result.subwords[i];
+      int token_id = 0;
+      if (tokenizer_->LookupId(token, &token_id)) {
+        input_tokens[input_token_index] = token_id;
+      } else {
+        input_tokens[input_token_index] = unknown_token_id;
+      }
+    }
+
+    RETURN_IF_ERROR(PopulateTensor(input_tokens, input_tensor));
+  } else {
+    RETURN_IF_ERROR(PopulateTensor(input, input_tensor));
+  }
+  return absl::OkStatus();
 }

 StatusOr<std::vector<Category>> NLClassifier::Postprocess(
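The regex branch of Preprocess above lays out a fixed-length id buffer: pre-fill with <PAD>, optionally write <START> at position 0, then copy one id per token, falling back to <UNKNOWN> for out-of-vocabulary tokens. A self-contained sketch of that layout using an ordinary map as the vocabulary (illustrative only; the commit uses RegexTokenizer and PopulateTensor):

#include <string>
#include <unordered_map>
#include <vector>

std::vector<int> BuildInputIds(const std::vector<std::string>& tokens,
                               const std::unordered_map<std::string, int>& vocab,
                               size_t max_sentence_length, int pad_id,
                               int unknown_id, bool has_start, int start_id) {
  std::vector<int> ids(max_sentence_length, pad_id);  // pre-fill with <PAD>
  size_t out = 0;
  if (has_start && out < max_sentence_length) {
    ids[out++] = start_id;  // optional <START> at the front
  }
  for (size_t i = 0; i < tokens.size() && out < max_sentence_length;
       ++i, ++out) {
    auto it = vocab.find(tokens[i]);
    ids[out] = (it != vocab.end()) ? it->second : unknown_id;  // <UNKNOWN> fallback
  }
  return ids;  // anything past the last token stays <PAD>
}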
@@ -196,23 +327,38 @@ absl::Status NLClassifier::Initialize(

 absl::Status NLClassifier::Initialize(const NLClassifierOptions& options) {
   struct_options_ = options;
-
-  int input_index = FindTensorIndex(
+  // input tensor should be type STRING
+  auto input_tensor = FindTensorWithNameOrIndex(
       GetInputTensors(), GetMetadataExtractor()->GetInputTensorMetadata(),
       options.input_tensor_name, options.input_tensor_index);
-
-  if (input_index < 0 || input_index >= GetInputCount()) {
+  if (input_tensor == nullptr) {
     return CreateStatusWithPayload(
         StatusCode::kInvalidArgument,
         absl::StrCat("No input tensor found with name ",
                      options.input_tensor_name, " or at index ",
                      options.input_tensor_index),
         TfLiteSupportStatus::kInputTensorNotFoundError);
   }
-
-  // Create preprocessor.
-  ASSIGN_OR_RETURN(preprocessor_, processor::TextPreprocessor::Create(
-                                      GetTfLiteEngine(), {input_index}));
+  if (HasRegexTokenizerMetadata()) {
+    if (input_tensor->type != kTfLiteInt32) {
+      return CreateStatusWithPayload(
+          StatusCode::kInvalidArgument,
+          absl::StrCat("Type mismatch for input tensor ", input_tensor->name,
+                       ". Requested INT32, got ",
+                       TfLiteTypeGetName(input_tensor->type), "."),
+          TfLiteSupportStatus::kInvalidInputTensorTypeError);
+    }
+    RETURN_IF_ERROR(SetupRegexTokenizer());
+  } else {
+    if (input_tensor->type != kTfLiteString) {
+      return CreateStatusWithPayload(
+          StatusCode::kInvalidArgument,
+          absl::StrCat("Type mismatch for input tensor ", input_tensor->name,
+                       ". Requested STRING, got ",
+                       TfLiteTypeGetName(input_tensor->type), "."),
+          TfLiteSupportStatus::kInvalidInputTensorTypeError);
+    }
+  }

   // output score tensor should be type
   // UINT8/INT8/INT16(quantized) or FLOAT32/FLOAT64(dequantized) or BOOL
@@ -334,6 +480,35 @@ StatusOr<std::unique_ptr<NLClassifier>> NLClassifier::CreateFromFdAndOptions(
   return std::move(nl_classifier);
 }

+bool NLClassifier::HasRegexTokenizerMetadata() {
+  const TensorMetadata* input_tensor_metadata =
+      GetMetadataExtractor()->GetInputTensorMetadata(
+          kRegexTokenizerInputTensorIndex);
+  if (input_tensor_metadata == nullptr) {
+    return false;
+  }
+  tflite::support::StatusOr<const tflite::ProcessUnit*> status =
+      GetMetadataExtractor()->FindFirstProcessUnit(
+          *input_tensor_metadata, ProcessUnitOptions_RegexTokenizerOptions);
+  return status.ok() ? status.value() != nullptr : false;
+}
+
+absl::Status NLClassifier::SetupRegexTokenizer() {
+  ASSIGN_OR_RETURN(
+      std::unique_ptr<Tokenizer> base_tokenizer,
+      CreateRegexTokenizerFromProcessUnit(
+          GetMetadataExtractor()
+              ->GetInputTensorMetadata(kRegexTokenizerInputTensorIndex)
+              ->process_units()
+              ->Get(kRegexTokenizerProcessUnitIndex),
+          GetMetadataExtractor()));
+
+  tokenizer_ = std::unique_ptr<RegexTokenizer>(
+      dynamic_cast<RegexTokenizer*>(base_tokenizer.release()));
+
+  return absl::OkStatus();
+}
+
 } // namespace nlclassifier
 } // namespace text
 } // namespace task

tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h

Lines changed: 8 additions & 23 deletions
@@ -34,8 +34,8 @@ limitations under the License.
 #include "tensorflow_lite_support/cc/port/statusor.h"
 #include "tensorflow_lite_support/cc/task/core/base_task_api.h"
 #include "tensorflow_lite_support/cc/task/core/category.h"
-#include "tensorflow_lite_support/cc/task/processor/text_preprocessor.h"
 #include "tensorflow_lite_support/cc/task/text/proto/nl_classifier_options_proto_inc.h"
+#include "tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.h"

 namespace tflite {
 namespace task {
@@ -181,41 +181,25 @@ class NLClassifier : public core::BaseTaskApi<std::vector<core::Category>,
       const flatbuffers::Vector<flatbuffers::Offset<TensorMetadata>>*
           metadata_array,
       const std::string& name, int index) {
-    int tensor_index = FindTensorIndex(tensors, metadata_array, name, index);
-    return tensor_index >= 0 && tensor_index < tensors.size()
-               ? tensors[tensor_index]
-               : nullptr;
-  }
-
-  // Gets the tensor index of the specified tensor name from a vector of tensors
-  // Return nullptr if no tensor is found by name (metadata tensor name or model
-  // tensor name).
-  template <typename TensorType>
-  static int FindTensorIndex(
-      const std::vector<TensorType*>& tensors,
-      const flatbuffers::Vector<flatbuffers::Offset<TensorMetadata>>*
-          metadata_array,
-      const std::string& name, int default_index) {
     if (metadata_array != nullptr && metadata_array->size() == tensors.size()) {
       for (int i = 0; i < metadata_array->size(); i++) {
         if (strcmp(name.data(), metadata_array->Get(i)->name()->c_str()) == 0) {
-          return i;
+          return tensors[i];
        }
      }
    }

-    for (int i = 0; i < tensors.size(); i++) {
-      TensorType* tensor = tensors[i];
+    for (TensorType* tensor : tensors) {
       if (tensor->name == name) {
-        return i;
+        return tensor;
       }
     }
-    return default_index;
+    return index >= 0 && index < tensors.size() ? tensors[index] : nullptr;
   }

  private:
-  std::unique_ptr<tflite::task::processor::TextPreprocessor> preprocessor_ =
-      nullptr;
+  bool HasRegexTokenizerMetadata();
+  absl::Status SetupRegexTokenizer();

   std::unique_ptr<tflite::task::text::NLClassifierOptions> proto_options_;

@@ -226,6 +210,7 @@ class NLClassifier : public core::BaseTaskApi<std::vector<core::Category>,
   // labels vector initialized from output tensor's associated file, if one
   // exists.
   std::unique_ptr<std::vector<std::string>> labels_vector_;
+  std::unique_ptr<tflite::support::text::tokenizer::RegexTokenizer> tokenizer_;
 };

 } // namespace nlclassifier
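With this change FindTensorWithNameOrIndex does the whole lookup itself: match the metadata tensor name first, then the model tensor name, and only fall back to the caller-supplied index when both miss. A simplified sketch of that lookup order with hypothetical FakeTensor/FindByNameOrIndex names (the real function is a template over TfLite tensor types):

#include <string>
#include <vector>

struct FakeTensor { std::string name; };  // hypothetical stand-in for TfLiteTensor

// Returns the tensor whose metadata name or own name matches `name`, else the
// tensor at `index` if that index is in range, else nullptr.
const FakeTensor* FindByNameOrIndex(const std::vector<FakeTensor>& tensors,
                                    const std::vector<std::string>& metadata_names,
                                    const std::string& name, int index) {
  if (metadata_names.size() == tensors.size()) {
    for (size_t i = 0; i < metadata_names.size(); ++i) {
      if (metadata_names[i] == name) return &tensors[i];  // metadata name match
    }
  }
  for (const FakeTensor& tensor : tensors) {
    if (tensor.name == name) return &tensor;  // model tensor name match
  }
  return (index >= 0 && static_cast<size_t>(index) < tensors.size())
             ? &tensors[index]
             : nullptr;  // index fallback
}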
