@@ -35,6 +35,8 @@ limitations under the License.
3535#include " tensorflow_lite_support/cc/task/core/category.h"
3636#include " tensorflow_lite_support/cc/task/core/task_api_factory.h"
3737#include " tensorflow_lite_support/cc/task/core/task_utils.h"
38+ #include " tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.h"
39+ #include " tensorflow_lite_support/cc/text/tokenizers/tokenizer.h"
3840#include " tensorflow_lite_support/cc/utils/common_utils.h"
3941
4042namespace tflite {
@@ -49,16 +51,22 @@ using ::tflite::TensorMetadata;
4951using ::tflite::support::CreateStatusWithPayload;
5052using ::tflite::support::StatusOr;
5153using ::tflite::support::TfLiteSupportStatus;
54+ using ::tflite::support::text::tokenizer::RegexTokenizer;
55+ using ::tflite::support::text::tokenizer::Tokenizer;
56+ using ::tflite::support::text::tokenizer::TokenizerResult;
5257using ::tflite::support::utils::LoadVocabFromBuffer;
5358using ::tflite::task::core::Category;
5459using ::tflite::task::core::Dequantize;
5560using ::tflite::task::core::GetStringAtIndex;
61+ using ::tflite::task::core::PopulateTensor;
5662using ::tflite::task::core::TaskAPIFactory;
5763// To differenciate it with the struct option,
5864// tflite::task::text::nl_classifier::NLClassifierOptions.
5965using NLClassifierProtoOptions = ::tflite::task::text::NLClassifierOptions;
6066
6167namespace {
68+ constexpr int kRegexTokenizerInputTensorIndex = 0 ;
69+ constexpr int kRegexTokenizerProcessUnitIndex = 0 ;
6270
6371absl::Status SanityCheckOptions (const NLClassifierProtoOptions& options) {
6472 if (!options.has_base_options ()) {
@@ -69,6 +77,78 @@ absl::Status SanityCheckOptions(const NLClassifierProtoOptions& options) {
6977 return absl::OkStatus ();
7078}
7179
80+ StatusOr<absl::string_view> CheckAndLoadFirstAssociatedFile (
81+ const flatbuffers::Vector<flatbuffers::Offset<tflite::AssociatedFile>>*
82+ associated_files,
83+ const tflite::metadata::ModelMetadataExtractor* metadata_extractor) {
84+ if (associated_files == nullptr || associated_files->size () < 1 ||
85+ associated_files->Get (0 )->name () == nullptr ) {
86+ return CreateStatusWithPayload (
87+ absl::StatusCode::kInvalidArgument ,
88+ " Invalid vocab_file from input process unit." ,
89+ TfLiteSupportStatus::kMetadataInvalidTokenizerError );
90+ }
91+ ASSIGN_OR_RETURN (absl::string_view vocab_buffer,
92+ metadata_extractor->GetAssociatedFile (
93+ associated_files->Get (0 )->name ()->str ()));
94+ return vocab_buffer;
95+ }
96+
97+ StatusOr<std::unique_ptr<Tokenizer>> CreateRegexTokenizerFromProcessUnit (
98+ const tflite::ProcessUnit* tokenizer_process_unit,
99+ const tflite::metadata::ModelMetadataExtractor* metadata_extractor) {
100+ if (metadata_extractor == nullptr || tokenizer_process_unit == nullptr ) {
101+ return CreateStatusWithPayload (
102+ absl::StatusCode::kInvalidArgument ,
103+ " No metadata or input process unit found." ,
104+ TfLiteSupportStatus::kMetadataInvalidTokenizerError );
105+ }
106+
107+ if (tokenizer_process_unit->options_type () !=
108+ ProcessUnitOptions_RegexTokenizerOptions) {
109+ return CreateStatusWithPayload (
110+ absl::StatusCode::kNotFound ,
111+ absl::StrCat (
112+ " Incorrect options_type:" , tokenizer_process_unit->options_type (),
113+ " need RegexTokenizerOptions." ),
114+ TfLiteSupportStatus::kMetadataInvalidTokenizerError );
115+ }
116+
117+ const tflite::RegexTokenizerOptions* options =
118+ tokenizer_process_unit->options_as <RegexTokenizerOptions>();
119+ ASSIGN_OR_RETURN (absl::string_view vocab_buffer,
120+ CheckAndLoadFirstAssociatedFile (options->vocab_file (),
121+ metadata_extractor));
122+ if (options->delim_regex_pattern () == nullptr ) {
123+ return CreateStatusWithPayload (
124+ absl::StatusCode::kInvalidArgument ,
125+ " Invalid delim_regex_pattern from input process unit." ,
126+ TfLiteSupportStatus::kMetadataInvalidTokenizerError );
127+ }
128+
129+ std::unique_ptr<RegexTokenizer> regex_tokenizer =
130+ absl::make_unique<RegexTokenizer>(options->delim_regex_pattern ()->str (),
131+ vocab_buffer.data (),
132+ vocab_buffer.size ());
133+
134+ int unknown_token_id = 0 ;
135+ if (!regex_tokenizer->GetUnknownToken (&unknown_token_id)) {
136+ return CreateStatusWithPayload (
137+ absl::StatusCode::kInvalidArgument ,
138+ " RegexTokenizer doesn't have <UNKNOWN> token." ,
139+ TfLiteSupportStatus::kMetadataInvalidTokenizerError );
140+ }
141+
142+ int pad_token_id = 0 ;
143+ if (!regex_tokenizer->GetPadToken (&pad_token_id)) {
144+ return CreateStatusWithPayload (
145+ absl::StatusCode::kInvalidArgument ,
146+ " RegexTokenizer doesn't have <PAD> token." ,
147+ TfLiteSupportStatus::kMetadataInvalidTokenizerError );
148+ }
149+ return std::move (regex_tokenizer);
150+ }
151+
72152} // namespace
73153
74154const NLClassifierOptions& NLClassifier::GetOptions () const {
@@ -121,7 +201,58 @@ std::vector<Category> NLClassifier::Classify(const std::string& text) {
121201
122202absl::Status NLClassifier::Preprocess (
123203 const std::vector<TfLiteTensor*>& input_tensors, const std::string& input) {
124- return preprocessor_->Preprocess (input);
204+ TfLiteTensor* input_tensor = FindTensorWithNameOrIndex (
205+ input_tensors, GetMetadataExtractor ()->GetInputTensorMetadata (),
206+ struct_options_.input_tensor_name , struct_options_.input_tensor_index );
207+ if (input_tensor == nullptr ) {
208+ return CreateStatusWithPayload (
209+ absl::StatusCode::kInvalidArgument ,
210+ " No input tensor found from NLClassifierOptions." ,
211+ TfLiteSupportStatus::kInputTensorNotFoundError );
212+ }
213+
214+ if (HasRegexTokenizerMetadata ()) {
215+ // |<-------sentence_length-------->|
216+ // input_tensor <START>, t1, t2... <PAD>, <PAD>...
217+ // <START> is optional, t1, t2... will be replaced by <UNKNOWN> if it's not
218+ // found in tokenizer vocab.
219+ TokenizerResult result = tokenizer_->Tokenize (input);
220+
221+ size_t max_sentence_length = input_tensor->dims ->size == 2
222+ ? input_tensor->dims ->data [1 ]
223+ : input_tensor->dims ->data [0 ];
224+
225+ int unknown_token_id = 0 ;
226+ tokenizer_->GetUnknownToken (&unknown_token_id);
227+
228+ int pad_token_id = 0 ;
229+ tokenizer_->GetPadToken (&pad_token_id);
230+
231+ std::vector<int > input_tokens (max_sentence_length, pad_token_id);
232+ int start_token_id = 0 ;
233+ size_t input_token_index = 0 ;
234+ if (tokenizer_->GetStartToken (&start_token_id)) {
235+ input_tokens[0 ] = start_token_id;
236+ input_token_index = 1 ;
237+ }
238+
239+ for (size_t i = 0 ; (i < result.subwords .size ()) &&
240+ (input_token_index < max_sentence_length);
241+ ++i, ++input_token_index) {
242+ const std::string& token = result.subwords [i];
243+ int token_id = 0 ;
244+ if (tokenizer_->LookupId (token, &token_id)) {
245+ input_tokens[input_token_index] = token_id;
246+ } else {
247+ input_tokens[input_token_index] = unknown_token_id;
248+ }
249+ }
250+
251+ RETURN_IF_ERROR (PopulateTensor (input_tokens, input_tensor));
252+ } else {
253+ RETURN_IF_ERROR (PopulateTensor (input, input_tensor));
254+ }
255+ return absl::OkStatus ();
125256}
126257
127258StatusOr<std::vector<Category>> NLClassifier::Postprocess (
@@ -196,23 +327,38 @@ absl::Status NLClassifier::Initialize(
196327
197328absl::Status NLClassifier::Initialize (const NLClassifierOptions& options) {
198329 struct_options_ = options;
199-
200- int input_index = FindTensorIndex (
330+ // input tensor should be type STRING
331+ auto input_tensor = FindTensorWithNameOrIndex (
201332 GetInputTensors (), GetMetadataExtractor ()->GetInputTensorMetadata (),
202333 options.input_tensor_name , options.input_tensor_index );
203-
204- if (input_index < 0 || input_index >= GetInputCount ()) {
334+ if (input_tensor == nullptr ) {
205335 return CreateStatusWithPayload (
206336 StatusCode::kInvalidArgument ,
207337 absl::StrCat (" No input tensor found with name " ,
208338 options.input_tensor_name , " or at index " ,
209339 options.input_tensor_index ),
210340 TfLiteSupportStatus::kInputTensorNotFoundError );
211341 }
212-
213- // Create preprocessor.
214- ASSIGN_OR_RETURN (preprocessor_, processor::TextPreprocessor::Create (
215- GetTfLiteEngine (), {input_index}));
342+ if (HasRegexTokenizerMetadata ()) {
343+ if (input_tensor->type != kTfLiteInt32 ) {
344+ return CreateStatusWithPayload (
345+ StatusCode::kInvalidArgument ,
346+ absl::StrCat (" Type mismatch for input tensor " , input_tensor->name ,
347+ " . Requested INT32, got " ,
348+ TfLiteTypeGetName (input_tensor->type ), " ." ),
349+ TfLiteSupportStatus::kInvalidInputTensorTypeError );
350+ }
351+ RETURN_IF_ERROR (SetupRegexTokenizer ());
352+ } else {
353+ if (input_tensor->type != kTfLiteString ) {
354+ return CreateStatusWithPayload (
355+ StatusCode::kInvalidArgument ,
356+ absl::StrCat (" Type mismatch for input tensor " , input_tensor->name ,
357+ " . Requested STRING, got " ,
358+ TfLiteTypeGetName (input_tensor->type ), " ." ),
359+ TfLiteSupportStatus::kInvalidInputTensorTypeError );
360+ }
361+ }
216362
217363 // output score tensor should be type
218364 // UINT8/INT8/INT16(quantized) or FLOAT32/FLOAT64(dequantized) or BOOL
@@ -334,6 +480,35 @@ StatusOr<std::unique_ptr<NLClassifier>> NLClassifier::CreateFromFdAndOptions(
334480 return std::move (nl_classifier);
335481}
336482
483+ bool NLClassifier::HasRegexTokenizerMetadata () {
484+ const TensorMetadata* input_tensor_metadata =
485+ GetMetadataExtractor ()->GetInputTensorMetadata (
486+ kRegexTokenizerInputTensorIndex );
487+ if (input_tensor_metadata == nullptr ) {
488+ return false ;
489+ }
490+ tflite::support::StatusOr<const tflite::ProcessUnit*> status =
491+ GetMetadataExtractor ()->FindFirstProcessUnit (
492+ *input_tensor_metadata, ProcessUnitOptions_RegexTokenizerOptions);
493+ return status.ok () ? status.value () != nullptr : false ;
494+ }
495+
496+ absl::Status NLClassifier::SetupRegexTokenizer () {
497+ ASSIGN_OR_RETURN (
498+ std::unique_ptr<Tokenizer> base_tokenizer,
499+ CreateRegexTokenizerFromProcessUnit (
500+ GetMetadataExtractor ()
501+ ->GetInputTensorMetadata (kRegexTokenizerInputTensorIndex )
502+ ->process_units ()
503+ ->Get (kRegexTokenizerProcessUnitIndex ),
504+ GetMetadataExtractor ()));
505+
506+ tokenizer_ = std::unique_ptr<RegexTokenizer>(
507+ dynamic_cast <RegexTokenizer*>(base_tokenizer.release ()));
508+
509+ return absl::OkStatus ();
510+ }
511+
337512} // namespace nlclassifier
338513} // namespace text
339514} // namespace task
0 commit comments