Skip to content

Commit 6f25185

Browse files
lu-wang-gtflite-support-robot
authored andcommitted
TextPreprocesor 1: TextPreprocesor for RegexTokenizer
Extracted out the text preprocessing logic from NLClassifier and created the TextPreprocessor for RegexTokenizer. PiperOrigin-RevId: 410028848
1 parent 4a88b96 commit 6f25185

File tree

4 files changed

+316
-2
lines changed

4 files changed

+316
-2
lines changed

tensorflow_lite_support/cc/task/processor/BUILD

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ cc_library_with_tflite(
1515
],
1616
deps = [
1717
"//tensorflow_lite_support/cc:common",
18+
"//tensorflow_lite_support/cc/port:status_macros",
1819
"//tensorflow_lite_support/cc/port:statusor",
1920
"@com_google_absl//absl/status",
2021
"@com_google_absl//absl/strings:str_format",
@@ -98,3 +99,25 @@ cc_library_with_tflite(
9899
"@com_google_absl//absl/strings:str_format",
99100
],
100101
)
102+
103+
cc_library_with_tflite(
104+
name = "text_preprocessor",
105+
srcs = ["text_preprocessor.cc"],
106+
hdrs = ["text_preprocessor.h"],
107+
tflite_deps = [
108+
":processor",
109+
"//tensorflow_lite_support/cc/task/core:tflite_engine",
110+
],
111+
deps = [
112+
"//tensorflow_lite_support/cc:common",
113+
"//tensorflow_lite_support/cc/port:status_macros",
114+
"//tensorflow_lite_support/cc/port:statusor",
115+
"//tensorflow_lite_support/cc/task/core:task_utils",
116+
"//tensorflow_lite_support/cc/text/tokenizers:regex_tokenizer",
117+
"//tensorflow_lite_support/cc/text/tokenizers:tokenizer",
118+
"//tensorflow_lite_support/cc/utils:common_utils",
119+
"@com_google_absl//absl/memory",
120+
"@com_google_absl//absl/status",
121+
"@com_google_absl//absl/strings:str_format",
122+
],
123+
)

tensorflow_lite_support/cc/task/processor/processor.h

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ limitations under the License.
2222
#include "absl/strings/str_format.h" // from @com_google_absl
2323
#include "tensorflow/lite/core/shims/c/common.h"
2424
#include "tensorflow_lite_support/cc/common.h"
25+
#include "tensorflow_lite_support/cc/port/status_macros.h"
2526
#include "tensorflow_lite_support/cc/port/statusor.h"
2627
#include "tensorflow_lite_support/cc/task/core/tflite_engine.h"
2728

@@ -88,6 +89,11 @@ class Processor {
8889
// `i` should be 1.
8990
virtual const tflite::TensorMetadata* GetTensorMetadata(int i = 0) const = 0;
9091

92+
inline const tflite::metadata::ModelMetadataExtractor* GetMetadataExtractor()
93+
const {
94+
return engine_->metadata_extractor();
95+
}
96+
9197
core::TfLiteEngine* engine_;
9298
const std::vector<int> tensor_indices_;
9399

@@ -127,7 +133,7 @@ class Preprocessor : public Processor {
127133
// Note: Caller is responsible for passing in a valid `i`.
128134
inline const tflite::TensorMetadata* GetTensorMetadata(
129135
int i = 0) const override {
130-
return engine_->metadata_extractor()->GetInputTensorMetadata(
136+
return GetMetadataExtractor()->GetInputTensorMetadata(
131137
tensor_indices_.at(i));
132138
}
133139

@@ -168,7 +174,7 @@ class Postprocessor : public Processor {
168174
// Note: Caller is responsible for passing in a valid `i`.
169175
inline const tflite::TensorMetadata* GetTensorMetadata(
170176
int i = 0) const override {
171-
return engine_->metadata_extractor()->GetOutputTensorMetadata(
177+
return GetMetadataExtractor()->GetOutputTensorMetadata(
172178
tensor_indices_.at(i));
173179
}
174180

Lines changed: 227 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,227 @@
1+
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
#include "tensorflow_lite_support/cc/task/processor/text_preprocessor.h"
16+
17+
#include "absl/status/status.h" // from @com_google_absl
18+
#include "tensorflow_lite_support/cc/common.h"
19+
#include "tensorflow_lite_support/cc/port/status_macros.h"
20+
#include "tensorflow_lite_support/cc/task/core/task_utils.h"
21+
#include "tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.h"
22+
#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer.h"
23+
#include "tensorflow_lite_support/cc/utils/common_utils.h"
24+
25+
namespace tflite {
26+
namespace task {
27+
namespace processor {
28+
29+
namespace {
30+
31+
using ::absl::StatusCode;
32+
using ::flatbuffers::Offset;
33+
using ::flatbuffers::Vector;
34+
using ::tflite::TensorMetadata;
35+
using ::tflite::support::CreateStatusWithPayload;
36+
using ::tflite::support::StatusOr;
37+
using ::tflite::support::TfLiteSupportStatus;
38+
using ::tflite::support::text::tokenizer::RegexTokenizer;
39+
using ::tflite::support::text::tokenizer::Tokenizer;
40+
using ::tflite::support::text::tokenizer::TokenizerResult;
41+
using ::tflite::task::core::PopulateTensor;
42+
43+
constexpr int kRegexTokenizerProcessUnitIndex = 0;
44+
45+
StatusOr<absl::string_view> CheckAndLoadFirstAssociatedFile(
46+
const flatbuffers::Vector<flatbuffers::Offset<tflite::AssociatedFile>>*
47+
associated_files,
48+
const tflite::metadata::ModelMetadataExtractor* metadata_extractor) {
49+
if (associated_files == nullptr || associated_files->size() < 1 ||
50+
associated_files->Get(0)->name() == nullptr) {
51+
return CreateStatusWithPayload(
52+
absl::StatusCode::kInvalidArgument,
53+
"Invalid vocab_file from input process unit.",
54+
TfLiteSupportStatus::kMetadataInvalidTokenizerError);
55+
}
56+
ASSIGN_OR_RETURN(absl::string_view vocab_buffer,
57+
metadata_extractor->GetAssociatedFile(
58+
associated_files->Get(0)->name()->str()));
59+
return vocab_buffer;
60+
}
61+
62+
StatusOr<std::unique_ptr<Tokenizer>> CreateRegexTokenizerFromProcessUnit(
63+
const tflite::ProcessUnit* tokenizer_process_unit,
64+
const tflite::metadata::ModelMetadataExtractor* metadata_extractor) {
65+
if (metadata_extractor == nullptr || tokenizer_process_unit == nullptr) {
66+
return CreateStatusWithPayload(
67+
absl::StatusCode::kInvalidArgument,
68+
"No metadata or input process unit found.",
69+
TfLiteSupportStatus::kMetadataInvalidTokenizerError);
70+
}
71+
72+
if (tokenizer_process_unit->options_type() !=
73+
ProcessUnitOptions_RegexTokenizerOptions) {
74+
return CreateStatusWithPayload(
75+
absl::StatusCode::kNotFound,
76+
absl::StrCat(
77+
"Incorrect options_type:", tokenizer_process_unit->options_type(),
78+
" need RegexTokenizerOptions."),
79+
TfLiteSupportStatus::kMetadataInvalidTokenizerError);
80+
}
81+
82+
const tflite::RegexTokenizerOptions* options =
83+
tokenizer_process_unit->options_as<RegexTokenizerOptions>();
84+
85+
ASSIGN_OR_RETURN(absl::string_view vocab_buffer,
86+
CheckAndLoadFirstAssociatedFile(options->vocab_file(),
87+
metadata_extractor));
88+
89+
if (options->delim_regex_pattern() == nullptr) {
90+
return CreateStatusWithPayload(
91+
absl::StatusCode::kInvalidArgument,
92+
"Invalid delim_regex_pattern from input process unit.",
93+
TfLiteSupportStatus::kMetadataInvalidTokenizerError);
94+
}
95+
96+
std::unique_ptr<RegexTokenizer> regex_tokenizer =
97+
absl::make_unique<RegexTokenizer>(options->delim_regex_pattern()->str(),
98+
vocab_buffer.data(),
99+
vocab_buffer.size());
100+
101+
int unknown_token_id = 0;
102+
if (!regex_tokenizer->GetUnknownToken(&unknown_token_id)) {
103+
return CreateStatusWithPayload(
104+
absl::StatusCode::kInvalidArgument,
105+
"RegexTokenizer doesn't have <UNKNOWN> token.",
106+
TfLiteSupportStatus::kMetadataInvalidTokenizerError);
107+
}
108+
109+
int pad_token_id = 0;
110+
if (!regex_tokenizer->GetPadToken(&pad_token_id)) {
111+
return CreateStatusWithPayload(
112+
absl::StatusCode::kInvalidArgument,
113+
"RegexTokenizer doesn't have <PAD> token.",
114+
TfLiteSupportStatus::kMetadataInvalidTokenizerError);
115+
}
116+
return std::move(regex_tokenizer);
117+
}
118+
119+
} // namespace
120+
121+
/* static */
122+
StatusOr<std::unique_ptr<TextPreprocessor>> TextPreprocessor::Create(
123+
tflite::task::core::TfLiteEngine* engine, int input_tensor_index) {
124+
ASSIGN_OR_RETURN(auto processor, Processor::Create<TextPreprocessor>(
125+
/* num_expected_tensors = */ 1, engine,
126+
{input_tensor_index},
127+
/* requires_metadata = */ false));
128+
RETURN_IF_ERROR(processor->Init());
129+
return processor;
130+
}
131+
132+
absl::Status TextPreprocessor::Init() {
133+
auto input_tensor = GetTensor();
134+
if (HasRegexTokenizerMetadata()) {
135+
if (input_tensor->type != kTfLiteInt32) {
136+
return CreateStatusWithPayload(
137+
StatusCode::kInvalidArgument,
138+
absl::StrCat("Type mismatch for input tensor ", input_tensor->name,
139+
". Requested INT32, got ",
140+
TfLiteTypeGetName(input_tensor->type), "."),
141+
TfLiteSupportStatus::kInvalidInputTensorTypeError);
142+
}
143+
RETURN_IF_ERROR(SetupRegexTokenizer());
144+
} else {
145+
if (input_tensor->type != kTfLiteString) {
146+
return CreateStatusWithPayload(
147+
StatusCode::kInvalidArgument,
148+
absl::StrCat("Type mismatch for input tensor ", input_tensor->name,
149+
". Requested STRING, got ",
150+
TfLiteTypeGetName(input_tensor->type), "."),
151+
TfLiteSupportStatus::kInvalidInputTensorTypeError);
152+
}
153+
}
154+
155+
return absl::OkStatus();
156+
}
157+
158+
absl::Status TextPreprocessor::Preprocess(const std::string& input_text) {
159+
TfLiteTensor* input_tensor = GetTensor();
160+
if (HasRegexTokenizerMetadata()) {
161+
// |<-------sentence_length-------->|
162+
// input_tensor <START>, t1, t2... <PAD>, <PAD>...
163+
// <START> is optional, t1, t2... will be replaced by <UNKNOWN> if it's not
164+
// found in tokenizer vocab.
165+
TokenizerResult result = tokenizer_->Tokenize(input_text);
166+
167+
size_t max_sentence_length = input_tensor->dims->size == 2
168+
? input_tensor->dims->data[1]
169+
: input_tensor->dims->data[0];
170+
171+
int unknown_token_id = 0;
172+
tokenizer_->GetUnknownToken(&unknown_token_id);
173+
174+
int pad_token_id = 0;
175+
tokenizer_->GetPadToken(&pad_token_id);
176+
177+
std::vector<int> input_tokens(max_sentence_length, pad_token_id);
178+
int start_token_id = 0;
179+
size_t input_token_index = 0;
180+
if (tokenizer_->GetStartToken(&start_token_id)) {
181+
input_tokens[0] = start_token_id;
182+
input_token_index = 1;
183+
}
184+
185+
for (size_t i = 0; (i < result.subwords.size()) &&
186+
(input_token_index < max_sentence_length);
187+
++i, ++input_token_index) {
188+
const std::string& token = result.subwords[i];
189+
int token_id = 0;
190+
if (tokenizer_->LookupId(token, &token_id)) {
191+
input_tokens[input_token_index] = token_id;
192+
} else {
193+
input_tokens[input_token_index] = unknown_token_id;
194+
}
195+
}
196+
return PopulateTensor(input_tokens, input_tensor);
197+
} else {
198+
return PopulateTensor(input_text, input_tensor);
199+
}
200+
}
201+
202+
bool TextPreprocessor::HasRegexTokenizerMetadata() {
203+
const TensorMetadata* metadata = GetTensorMetadata();
204+
if (metadata == nullptr) {
205+
return false;
206+
}
207+
auto status_or = GetMetadataExtractor()->FindFirstProcessUnit(
208+
*metadata, ProcessUnitOptions_RegexTokenizerOptions);
209+
return status_or.ok() ? status_or.value() != nullptr : false;
210+
}
211+
212+
absl::Status TextPreprocessor::SetupRegexTokenizer() {
213+
ASSIGN_OR_RETURN(std::unique_ptr<Tokenizer> base_tokenizer,
214+
CreateRegexTokenizerFromProcessUnit(
215+
GetTensorMetadata()->process_units()->Get(
216+
kRegexTokenizerProcessUnitIndex),
217+
GetMetadataExtractor()));
218+
219+
tokenizer_ = std::unique_ptr<RegexTokenizer>(
220+
dynamic_cast<RegexTokenizer*>(base_tokenizer.release()));
221+
222+
return absl::OkStatus();
223+
}
224+
225+
} // namespace processor
226+
} // namespace task
227+
} // namespace tflite
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_PROCESSOR_TEXT_PREPROCESSOR_H_
16+
#define TENSORFLOW_LITE_SUPPORT_CC_TASK_PROCESSOR_TEXT_PREPROCESSOR_H_
17+
18+
#include "absl/status/status.h" // from @com_google_absl
19+
#include "tensorflow_lite_support/cc/port/statusor.h"
20+
#include "tensorflow_lite_support/cc/task/core/tflite_engine.h"
21+
#include "tensorflow_lite_support/cc/task/processor/processor.h"
22+
#include "tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.h"
23+
24+
namespace tflite {
25+
namespace task {
26+
namespace processor {
27+
28+
// Processes input text and populates the associated input tensor.
29+
// Requirements for the input tensor:
30+
// (kTfLiteString) - input of the model, accepts a string.
31+
// or
32+
// (kTfLiteInt32) - input of the model, accepts a tokenized indices of a
33+
// string input. A RegexTokenizer needs to be set up in the input tensor's
34+
// metadata.
35+
class TextPreprocessor : public Preprocessor {
36+
public:
37+
static tflite::support::StatusOr<std::unique_ptr<TextPreprocessor>> Create(
38+
tflite::task::core::TfLiteEngine* engine, int input_tensor_index);
39+
40+
absl::Status Preprocess(const std::string& text);
41+
42+
private:
43+
using Preprocessor::Preprocessor;
44+
45+
absl::Status Init();
46+
47+
bool HasRegexTokenizerMetadata();
48+
49+
absl::Status SetupRegexTokenizer();
50+
51+
std::unique_ptr<tflite::support::text::tokenizer::RegexTokenizer> tokenizer_;
52+
};
53+
54+
} // namespace processor
55+
} // namespace task
56+
} // namespace tflite
57+
58+
#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_PROCESSOR_TEXT_PREPROCESSOR_H_

0 commit comments

Comments
 (0)