@@ -35,8 +35,6 @@ limitations under the License.
3535#include " tensorflow_lite_support/cc/task/core/category.h"
3636#include " tensorflow_lite_support/cc/task/core/task_api_factory.h"
3737#include " tensorflow_lite_support/cc/task/core/task_utils.h"
38- #include " tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.h"
39- #include " tensorflow_lite_support/cc/text/tokenizers/tokenizer.h"
4038#include " tensorflow_lite_support/cc/utils/common_utils.h"
4139
4240namespace tflite {
@@ -51,22 +49,16 @@ using ::tflite::TensorMetadata;
5149using ::tflite::support::CreateStatusWithPayload;
5250using ::tflite::support::StatusOr;
5351using ::tflite::support::TfLiteSupportStatus;
54- using ::tflite::support::text::tokenizer::RegexTokenizer;
55- using ::tflite::support::text::tokenizer::Tokenizer;
56- using ::tflite::support::text::tokenizer::TokenizerResult;
5752using ::tflite::support::utils::LoadVocabFromBuffer;
5853using ::tflite::task::core::Category;
5954using ::tflite::task::core::Dequantize;
6055using ::tflite::task::core::GetStringAtIndex;
61- using ::tflite::task::core::PopulateTensor;
6256using ::tflite::task::core::TaskAPIFactory;
6357// To differenciate it with the struct option,
6458// tflite::task::text::nl_classifier::NLClassifierOptions.
6559using NLClassifierProtoOptions = ::tflite::task::text::NLClassifierOptions;
6660
6761namespace {
68- constexpr int kRegexTokenizerInputTensorIndex = 0 ;
69- constexpr int kRegexTokenizerProcessUnitIndex = 0 ;
7062
7163absl::Status SanityCheckOptions (const NLClassifierProtoOptions& options) {
7264 if (!options.has_base_options ()) {
@@ -77,78 +69,6 @@ absl::Status SanityCheckOptions(const NLClassifierProtoOptions& options) {
7769 return absl::OkStatus ();
7870}
7971
80- StatusOr<absl::string_view> CheckAndLoadFirstAssociatedFile (
81- const flatbuffers::Vector<flatbuffers::Offset<tflite::AssociatedFile>>*
82- associated_files,
83- const tflite::metadata::ModelMetadataExtractor* metadata_extractor) {
84- if (associated_files == nullptr || associated_files->size () < 1 ||
85- associated_files->Get (0 )->name () == nullptr ) {
86- return CreateStatusWithPayload (
87- absl::StatusCode::kInvalidArgument ,
88- " Invalid vocab_file from input process unit." ,
89- TfLiteSupportStatus::kMetadataInvalidTokenizerError );
90- }
91- ASSIGN_OR_RETURN (absl::string_view vocab_buffer,
92- metadata_extractor->GetAssociatedFile (
93- associated_files->Get (0 )->name ()->str ()));
94- return vocab_buffer;
95- }
96-
97- StatusOr<std::unique_ptr<Tokenizer>> CreateRegexTokenizerFromProcessUnit (
98- const tflite::ProcessUnit* tokenizer_process_unit,
99- const tflite::metadata::ModelMetadataExtractor* metadata_extractor) {
100- if (metadata_extractor == nullptr || tokenizer_process_unit == nullptr ) {
101- return CreateStatusWithPayload (
102- absl::StatusCode::kInvalidArgument ,
103- " No metadata or input process unit found." ,
104- TfLiteSupportStatus::kMetadataInvalidTokenizerError );
105- }
106-
107- if (tokenizer_process_unit->options_type () !=
108- ProcessUnitOptions_RegexTokenizerOptions) {
109- return CreateStatusWithPayload (
110- absl::StatusCode::kNotFound ,
111- absl::StrCat (
112- " Incorrect options_type:" , tokenizer_process_unit->options_type (),
113- " need RegexTokenizerOptions." ),
114- TfLiteSupportStatus::kMetadataInvalidTokenizerError );
115- }
116-
117- const tflite::RegexTokenizerOptions* options =
118- tokenizer_process_unit->options_as <RegexTokenizerOptions>();
119- ASSIGN_OR_RETURN (absl::string_view vocab_buffer,
120- CheckAndLoadFirstAssociatedFile (options->vocab_file (),
121- metadata_extractor));
122- if (options->delim_regex_pattern () == nullptr ) {
123- return CreateStatusWithPayload (
124- absl::StatusCode::kInvalidArgument ,
125- " Invalid delim_regex_pattern from input process unit." ,
126- TfLiteSupportStatus::kMetadataInvalidTokenizerError );
127- }
128-
129- std::unique_ptr<RegexTokenizer> regex_tokenizer =
130- absl::make_unique<RegexTokenizer>(options->delim_regex_pattern ()->str (),
131- vocab_buffer.data (),
132- vocab_buffer.size ());
133-
134- int unknown_token_id = 0 ;
135- if (!regex_tokenizer->GetUnknownToken (&unknown_token_id)) {
136- return CreateStatusWithPayload (
137- absl::StatusCode::kInvalidArgument ,
138- " RegexTokenizer doesn't have <UNKNOWN> token." ,
139- TfLiteSupportStatus::kMetadataInvalidTokenizerError );
140- }
141-
142- int pad_token_id = 0 ;
143- if (!regex_tokenizer->GetPadToken (&pad_token_id)) {
144- return CreateStatusWithPayload (
145- absl::StatusCode::kInvalidArgument ,
146- " RegexTokenizer doesn't have <PAD> token." ,
147- TfLiteSupportStatus::kMetadataInvalidTokenizerError );
148- }
149- return std::move (regex_tokenizer);
150- }
151-
15272} // namespace
15373
15474const NLClassifierOptions& NLClassifier::GetOptions () const {
@@ -201,58 +121,7 @@ std::vector<Category> NLClassifier::Classify(const std::string& text) {
201121
202122absl::Status NLClassifier::Preprocess (
203123 const std::vector<TfLiteTensor*>& input_tensors, const std::string& input) {
204- TfLiteTensor* input_tensor = FindTensorWithNameOrIndex (
205- input_tensors, GetMetadataExtractor ()->GetInputTensorMetadata (),
206- struct_options_.input_tensor_name , struct_options_.input_tensor_index );
207- if (input_tensor == nullptr ) {
208- return CreateStatusWithPayload (
209- absl::StatusCode::kInvalidArgument ,
210- " No input tensor found from NLClassifierOptions." ,
211- TfLiteSupportStatus::kInputTensorNotFoundError );
212- }
213-
214- if (HasRegexTokenizerMetadata ()) {
215- // |<-------sentence_length-------->|
216- // input_tensor <START>, t1, t2... <PAD>, <PAD>...
217- // <START> is optional, t1, t2... will be replaced by <UNKNOWN> if it's not
218- // found in tokenizer vocab.
219- TokenizerResult result = tokenizer_->Tokenize (input);
220-
221- size_t max_sentence_length = input_tensor->dims ->size == 2
222- ? input_tensor->dims ->data [1 ]
223- : input_tensor->dims ->data [0 ];
224-
225- int unknown_token_id = 0 ;
226- tokenizer_->GetUnknownToken (&unknown_token_id);
227-
228- int pad_token_id = 0 ;
229- tokenizer_->GetPadToken (&pad_token_id);
230-
231- std::vector<int > input_tokens (max_sentence_length, pad_token_id);
232- int start_token_id = 0 ;
233- size_t input_token_index = 0 ;
234- if (tokenizer_->GetStartToken (&start_token_id)) {
235- input_tokens[0 ] = start_token_id;
236- input_token_index = 1 ;
237- }
238-
239- for (size_t i = 0 ; (i < result.subwords .size ()) &&
240- (input_token_index < max_sentence_length);
241- ++i, ++input_token_index) {
242- const std::string& token = result.subwords [i];
243- int token_id = 0 ;
244- if (tokenizer_->LookupId (token, &token_id)) {
245- input_tokens[input_token_index] = token_id;
246- } else {
247- input_tokens[input_token_index] = unknown_token_id;
248- }
249- }
250-
251- RETURN_IF_ERROR (PopulateTensor (input_tokens, input_tensor));
252- } else {
253- RETURN_IF_ERROR (PopulateTensor (input, input_tensor));
254- }
255- return absl::OkStatus ();
124+ return preprocessor_->Preprocess (input);
256125}
257126
258127StatusOr<std::vector<Category>> NLClassifier::Postprocess (
@@ -327,38 +196,23 @@ absl::Status NLClassifier::Initialize(
327196
328197absl::Status NLClassifier::Initialize (const NLClassifierOptions& options) {
329198 struct_options_ = options;
330- // input tensor should be type STRING
331- auto input_tensor = FindTensorWithNameOrIndex (
199+
200+ int input_index = FindTensorIndex (
332201 GetInputTensors (), GetMetadataExtractor ()->GetInputTensorMetadata (),
333202 options.input_tensor_name , options.input_tensor_index );
334- if (input_tensor == nullptr ) {
203+
204+ if (input_index < 0 || input_index >= GetInputCount ()) {
335205 return CreateStatusWithPayload (
336206 StatusCode::kInvalidArgument ,
337207 absl::StrCat (" No input tensor found with name " ,
338208 options.input_tensor_name , " or at index " ,
339209 options.input_tensor_index ),
340210 TfLiteSupportStatus::kInputTensorNotFoundError );
341211 }
342- if (HasRegexTokenizerMetadata ()) {
343- if (input_tensor->type != kTfLiteInt32 ) {
344- return CreateStatusWithPayload (
345- StatusCode::kInvalidArgument ,
346- absl::StrCat (" Type mismatch for input tensor " , input_tensor->name ,
347- " . Requested INT32, got " ,
348- TfLiteTypeGetName (input_tensor->type ), " ." ),
349- TfLiteSupportStatus::kInvalidInputTensorTypeError );
350- }
351- RETURN_IF_ERROR (SetupRegexTokenizer ());
352- } else {
353- if (input_tensor->type != kTfLiteString ) {
354- return CreateStatusWithPayload (
355- StatusCode::kInvalidArgument ,
356- absl::StrCat (" Type mismatch for input tensor " , input_tensor->name ,
357- " . Requested STRING, got " ,
358- TfLiteTypeGetName (input_tensor->type ), " ." ),
359- TfLiteSupportStatus::kInvalidInputTensorTypeError );
360- }
361- }
212+
213+ // Create preprocessor.
214+ ASSIGN_OR_RETURN (preprocessor_, processor::TextPreprocessor::Create (
215+ GetTfLiteEngine (), {input_index}));
362216
363217 // output score tensor should be type
364218 // UINT8/INT8/INT16(quantized) or FLOAT32/FLOAT64(dequantized) or BOOL
@@ -480,35 +334,6 @@ StatusOr<std::unique_ptr<NLClassifier>> NLClassifier::CreateFromFdAndOptions(
480334 return std::move (nl_classifier);
481335}
482336
483- bool NLClassifier::HasRegexTokenizerMetadata () {
484- const TensorMetadata* input_tensor_metadata =
485- GetMetadataExtractor ()->GetInputTensorMetadata (
486- kRegexTokenizerInputTensorIndex );
487- if (input_tensor_metadata == nullptr ) {
488- return false ;
489- }
490- tflite::support::StatusOr<const tflite::ProcessUnit*> status =
491- GetMetadataExtractor ()->FindFirstProcessUnit (
492- *input_tensor_metadata, ProcessUnitOptions_RegexTokenizerOptions);
493- return status.ok () ? status.value () != nullptr : false ;
494- }
495-
496- absl::Status NLClassifier::SetupRegexTokenizer () {
497- ASSIGN_OR_RETURN (
498- std::unique_ptr<Tokenizer> base_tokenizer,
499- CreateRegexTokenizerFromProcessUnit (
500- GetMetadataExtractor ()
501- ->GetInputTensorMetadata (kRegexTokenizerInputTensorIndex )
502- ->process_units ()
503- ->Get (kRegexTokenizerProcessUnitIndex ),
504- GetMetadataExtractor ()));
505-
506- tokenizer_ = std::unique_ptr<RegexTokenizer>(
507- dynamic_cast <RegexTokenizer*>(base_tokenizer.release ()));
508-
509- return absl::OkStatus ();
510- }
511-
512337} // namespace nlclassifier
513338} // namespace text
514339} // namespace task
0 commit comments