Skip to content

Commit fac2a17

Browse files
lu-wang-gtflite-support-robot
authored andcommitted
Task delegate 10: refactor BertQuestionAnswerer
This is a preparation for integrating BaseOptions. Major changes include: (1) Switched to BertQuestionAnswerer::CreateFromOptions (the new API) in the JNI layer. (2) Properly returned error codes and messages from JNI. (3) Updated the Javadoc. PiperOrigin-RevId: 400354463
1 parent 293f2b9 commit fac2a17

File tree

2 files changed

+122
-81
lines changed

2 files changed

+122
-81
lines changed

tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/BertQuestionAnswerer.java

Lines changed: 71 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -26,36 +26,39 @@
2626
import org.tensorflow.lite.task.core.TaskJniUtils.EmptyHandleProvider;
2727
import org.tensorflow.lite.task.core.TaskJniUtils.MultipleBuffersHandleProvider;
2828

29-
/** Task API for BertQA models. */
29+
/**
30+
* Returns the most possible answers on a given question for QA models (BERT, Albert, etc.).
31+
*
32+
* <p>The API expects a Bert based TFLite model with metadata containing the following information:
33+
*
34+
* <ul>
35+
* <li>input_process_units for Wordpiece/Sentencepiece Tokenizer - Wordpiece Tokenizer can be used
36+
* for a <a
37+
* href="https://tfhub.dev/tensorflow/lite-model/mobilebert/1/default/1">MobileBert</a> model,
38+
* Sentencepiece Tokenizer Tokenizer can be used for an <a
39+
* href="https://tfhub.dev/tensorflow/lite-model/albert_lite_base/squadv1/1">Albert</a> model.
40+
* <li>3 input tensors with names "ids", "mask" and "segment_ids".
41+
* <li>2 output tensors with names "end_logits" and "start_logits".
42+
* </ul>
43+
*/
3044
public class BertQuestionAnswerer extends BaseTaskApi implements QuestionAnswerer {
3145
private static final String BERT_QUESTION_ANSWERER_NATIVE_LIBNAME = "task_text_jni";
32-
33-
private BertQuestionAnswerer(long nativeHandle) {
34-
super(nativeHandle);
35-
}
46+
private static final int OPTIONAL_FD_LENGTH = -1;
47+
private static final int OPTIONAL_FD_OFFSET = -1;
3648

3749
/**
38-
* Generic API to create the QuestionAnswerer for bert models with metadata populated. The API
39-
* expects a Bert based TFLite model with metadata containing the following information:
40-
*
41-
* <ul>
42-
* <li>input_process_units for Wordpiece/Sentencepiece Tokenizer - Wordpiece Tokenizer can be
43-
* used for a <a
44-
* href="https://tfhub.dev/tensorflow/lite-model/mobilebert/1/default/1">MobileBert</a>
45-
* model, Sentencepiece Tokenizer Tokenizer can be used for an <a
46-
* href="https://tfhub.dev/tensorflow/lite-model/albert_lite_base/squadv1/1">Albert</a>
47-
* model.
48-
* <li>3 input tensors with names "ids", "mask" and "segment_ids".
49-
* <li>2 output tensors with names "end_logits" and "start_logits".
50-
* </ul>
50+
* Creates a {@link BertQuestionAnswerer} instance from the default {@link
51+
* BertQuestionAnswererOptions}.
5152
*
5253
* @param context android context
53-
* @param pathToModel file path to the model with metadata. Note: The model should not be
54-
* compressed
55-
* @return {@link BertQuestionAnswerer} instance
56-
* @throws IOException If model file fails to load.
54+
* @param modelPath file path to the model with metadata. Note: The model should not be compressed
55+
* @return a {@link BertQuestionAnswerer} instance
56+
* @throws IOException if model file fails to load
57+
* @throws IllegalArgumentException if an argument is invalid
58+
* @throws IllegalStateException if there is an internal error
59+
* @throws RuntimeException if there is an otherwise unspecified error
5760
*/
58-
public static BertQuestionAnswerer createFromFile(Context context, String pathToModel)
61+
public static BertQuestionAnswerer createFromFile(Context context, String modelPath)
5962
throws IOException {
6063
return new BertQuestionAnswerer(
6164
TaskJniUtils.createHandleWithMultipleAssetFilesFromLibrary(
@@ -67,97 +70,98 @@ public long createHandle(ByteBuffer... buffers) {
6770
}
6871
},
6972
BERT_QUESTION_ANSWERER_NATIVE_LIBNAME,
70-
pathToModel));
73+
modelPath));
7174
}
7275

7376
/**
74-
* Generic API to create the QuestionAnswerer for bert models with metadata populated. The API
75-
* expects a Bert based TFLite model with metadata containing the following information:
76-
*
77-
* <ul>
78-
* <li>input_process_units for Wordpiece/Sentencepiece Tokenizer - Wordpiece Tokenizer can be
79-
* used for a <a
80-
* href="https://tfhub.dev/tensorflow/lite-model/mobilebert/1/default/1">MobileBert</a>
81-
* model, Sentencepiece Tokenizer Tokenizer can be used for an <a
82-
* href="https://tfhub.dev/tensorflow/lite-model/albert_lite_base/squadv1/1">Albert</a>
83-
* model.
84-
* <li>3 input tensors with names "ids", "mask" and "segment_ids".
85-
* <li>2 output tensors with names "end_logits" and "start_logits".
86-
* </ul>
77+
* Creates a {@link BertQuestionAnswerer} instance from the default {@link
78+
* BertQuestionAnswererOptions}.
8779
*
88-
* @param modelFile {@link File} object of the model
89-
* @return {@link BertQuestionAnswerer} instance
90-
* @throws IOException If model file fails to load.
80+
* @param modelFile a {@link File} object of the model
81+
* @return a {@link BertQuestionAnswerer} instance
82+
* @throws IOException if model file fails to load
83+
* @throws IllegalArgumentException if an argument is invalid
84+
* @throws IllegalStateException if there is an internal error
85+
* @throws RuntimeException if there is an otherwise unspecified error
9186
*/
92-
public static BertQuestionAnswerer createFromFile(File modelFile)
93-
throws IOException {
87+
public static BertQuestionAnswerer createFromFile(File modelFile) throws IOException {
9488
try (ParcelFileDescriptor descriptor =
9589
ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
9690
return new BertQuestionAnswerer(
9791
TaskJniUtils.createHandleFromLibrary(
9892
new EmptyHandleProvider() {
9993
@Override
10094
public long createHandle() {
101-
return initJniWithFileDescriptor(descriptor.getFd());
95+
return initJniWithFileDescriptor(
96+
/*fileDescriptor=*/ descriptor.getFd(),
97+
/*fileDescriptorLength=*/ OPTIONAL_FD_LENGTH,
98+
/*fileDescriptorOffset=*/ OPTIONAL_FD_OFFSET);
10299
}
103100
},
104101
BERT_QUESTION_ANSWERER_NATIVE_LIBNAME));
105102
}
106103
}
107104

108105
/**
109-
* Creates the API instance with a bert model and vocabulary file.
106+
* Creates a {@link BertQuestionAnswerer} instance with a Bert model and a vocabulary file.
110107
*
111108
* <p>One suitable model is: https://tfhub.dev/tensorflow/lite-model/mobilebert/1/default/1
112109
*
113110
* @param context android context
114-
* @param pathToModel file path to the bert model. Note: The model should not be compressed
115-
* @param pathToVocab file path to the vocabulary file. Note: The file should not be compressed
116-
* @return {@link BertQuestionAnswerer} instance
117-
* @throws IOException If model file fails to load.
111+
* @param modelPath file path to the Bert model. Note: The model should not be compressed
112+
* @param vocabPath file path to the vocabulary file. Note: The file should not be compressed
113+
* @return a {@link BertQuestionAnswerer} instance
114+
* @throws IOException If model file fails to load
115+
* @throws IllegalArgumentException if an argument is invalid
116+
* @throws IllegalStateException if there is an internal error
117+
* @throws RuntimeException if there is an otherwise unspecified error
118118
*/
119119
public static BertQuestionAnswerer createBertQuestionAnswererFromFile(
120-
Context context, String pathToModel, String pathToVocab) throws IOException {
120+
Context context, String modelPath, String vocabPath) throws IOException {
121121
return new BertQuestionAnswerer(
122122
TaskJniUtils.createHandleWithMultipleAssetFilesFromLibrary(
123123
context,
124124
new MultipleBuffersHandleProvider() {
125125
@Override
126126
public long createHandle(ByteBuffer... buffers) {
127-
return BertQuestionAnswerer.initJniWithBertByteBuffers(buffers);
127+
return initJniWithBertByteBuffers(buffers);
128128
}
129129
},
130130
BERT_QUESTION_ANSWERER_NATIVE_LIBNAME,
131-
pathToModel,
132-
pathToVocab));
131+
modelPath,
132+
vocabPath));
133133
}
134134

135135
/**
136-
* Creates the API instance with an albert model and sentence piece model file.
136+
* Creates a {@link BertQuestionAnswerer} instance with an Albert model and a sentence piece model
137+
* file.
137138
*
138139
* <p>One suitable model is: https://tfhub.dev/tensorflow/lite-model/albert_lite_base/squadv1/1
139140
*
140141
* @param context android context
141-
* @param pathToModel file path to the albert model. Note: The model should not be compressed
142-
* @param pathToSentencePieceModel file path to the sentence piece model file. Note: The model
142+
* @param modelPath file path to the Albert model. Note: The model should not be compressed
143+
* @param sentencePieceModelPath file path to the sentence piece model file. Note: The model
143144
* should not be compressed
144-
* @return {@link BertQuestionAnswerer} instance
145-
* @throws IOException If model file fails to load.
145+
* @return a {@link BertQuestionAnswerer} instance
146+
* @throws IOException If model file fails to load
147+
* @throws IllegalArgumentException if an argument is invalid
148+
* @throws IllegalStateException if there is an internal error
149+
* @throws RuntimeException if there is an otherwise unspecified error
146150
*/
147151
public static BertQuestionAnswerer createAlbertQuestionAnswererFromFile(
148-
Context context, String pathToModel, String pathToSentencePieceModel) throws IOException {
152+
Context context, String modelPath, String sentencePieceModelPath) throws IOException {
149153
return new BertQuestionAnswerer(
150154
TaskJniUtils.createHandleWithMultipleAssetFilesFromLibrary(
151155
context,
152156
new MultipleBuffersHandleProvider() {
153157
@Override
154158
public long createHandle(ByteBuffer... buffers) {
155-
return BertQuestionAnswerer.initJniWithAlbertByteBuffers(buffers);
159+
return initJniWithAlbertByteBuffers(buffers);
156160
}
157161
},
158162
BERT_QUESTION_ANSWERER_NATIVE_LIBNAME,
159-
pathToModel,
160-
pathToSentencePieceModel));
163+
modelPath,
164+
sentencePieceModelPath));
161165
}
162166

163167
@Override
@@ -166,6 +170,10 @@ public List<QaAnswer> answer(String context, String question) {
166170
return answerNative(getNativeHandle(), context, question);
167171
}
168172

173+
private BertQuestionAnswerer(long nativeHandle) {
174+
super(nativeHandle);
175+
}
176+
169177
// modelBuffers[0] is tflite model file buffer, and modelBuffers[1] is vocab file buffer.
170178
private static native long initJniWithBertByteBuffers(ByteBuffer... modelBuffers);
171179

@@ -176,7 +184,8 @@ public List<QaAnswer> answer(String context, String question) {
176184
// modelBuffers[0] is tflite model file buffer with metadata to specify which tokenizer to use.
177185
private static native long initJniWithModelWithMetadataByteBuffers(ByteBuffer... modelBuffers);
178186

179-
private static native long initJniWithFileDescriptor(int fd);
187+
private static native long initJniWithFileDescriptor(
188+
int fileDescriptor, long fileDescriptorLength, long fileDescriptorOffset);
180189

181190
private static native List<QaAnswer> answerNative(
182191
long nativeHandle, String context, String question);

tensorflow_lite_support/java/src/native/task/text/qa/bert_question_answerer_jni.cc

Lines changed: 51 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,21 @@ limitations under the License.
2020

2121
namespace {
2222

23+
using ::tflite::support::StatusOr;
2324
using ::tflite::support::utils::ConvertVectorToArrayList;
25+
using ::tflite::support::utils::GetExceptionClassNameForStatusCode;
2426
using ::tflite::support::utils::GetMappedFileBuffer;
2527
using ::tflite::support::utils::JStringToString;
28+
using ::tflite::support::utils::ThrowException;
2629
using ::tflite::task::text::BertQuestionAnswerer;
30+
using ::tflite::task::text::BertQuestionAnswererOptions;
2731
using ::tflite::task::text::QaAnswer;
2832
using ::tflite::task::text::QuestionAnswerer;
2933

3034
constexpr int kInvalidPointer = 0;
3135

36+
} // namespace
37+
3238
extern "C" JNIEXPORT void JNICALL
3339
Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_deinitJni(
3440
JNIEnv* env, jobject thiz, jlong native_handle) {
@@ -41,24 +47,44 @@ Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_initJniWithModelWithM
4147
absl::string_view model_with_metadata =
4248
GetMappedFileBuffer(env, env->GetObjectArrayElement(model_buffers, 0));
4349

44-
tflite::support::StatusOr<std::unique_ptr<QuestionAnswerer>> status =
45-
BertQuestionAnswerer::CreateFromBuffer(
46-
model_with_metadata.data(), model_with_metadata.size());
47-
if (status.ok()) {
48-
return reinterpret_cast<jlong>(status->release());
50+
tflite::support::StatusOr<std::unique_ptr<QuestionAnswerer>> qa_status =
51+
BertQuestionAnswerer::CreateFromBuffer(model_with_metadata.data(),
52+
model_with_metadata.size());
53+
if (qa_status.ok()) {
54+
return reinterpret_cast<jlong>(qa_status->release());
4955
} else {
56+
ThrowException(
57+
env, GetExceptionClassNameForStatusCode(qa_status.status().code()),
58+
"Error occurred when initializing BertQuestionAnswerer: %s",
59+
qa_status.status().message().data());
5060
return kInvalidPointer;
5161
}
5262
}
53-
5463
extern "C" JNIEXPORT jlong JNICALL
5564
Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_initJniWithFileDescriptor(
56-
JNIEnv* env, jclass thiz, jint fd) {
57-
tflite::support::StatusOr<std::unique_ptr<QuestionAnswerer>> status =
58-
BertQuestionAnswerer::CreateFromFd(fd);
59-
if (status.ok()) {
60-
return reinterpret_cast<jlong>(status->release());
65+
JNIEnv* env, jclass thiz, jint file_descriptor,
66+
jlong file_descriptor_length, jlong file_descriptor_offset) {
67+
BertQuestionAnswererOptions proto_options;
68+
auto file_descriptor_meta = proto_options.mutable_base_options()
69+
->mutable_model_file()
70+
->mutable_file_descriptor_meta();
71+
file_descriptor_meta->set_fd(file_descriptor);
72+
if (file_descriptor_length > 0) {
73+
file_descriptor_meta->set_length(file_descriptor_length);
74+
}
75+
if (file_descriptor_offset > 0) {
76+
file_descriptor_meta->set_offset(file_descriptor_offset);
77+
}
78+
79+
StatusOr<std::unique_ptr<QuestionAnswerer>> qa_status =
80+
BertQuestionAnswerer::CreateFromOptions(proto_options);
81+
if (qa_status.ok()) {
82+
return reinterpret_cast<jlong>(qa_status->release());
6183
} else {
84+
ThrowException(
85+
env, GetExceptionClassNameForStatusCode(qa_status.status().code()),
86+
"Error occurred when initializing BertQuestionAnswerer: %s",
87+
qa_status.status().message().data());
6288
return kInvalidPointer;
6389
}
6490
}
@@ -71,12 +97,16 @@ Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_initJniWithBertByteBu
7197
absl::string_view vocab =
7298
GetMappedFileBuffer(env, env->GetObjectArrayElement(model_buffers, 1));
7399

74-
tflite::support::StatusOr<std::unique_ptr<QuestionAnswerer>> status =
100+
StatusOr<std::unique_ptr<QuestionAnswerer>> qa_status =
75101
BertQuestionAnswerer::CreateBertQuestionAnswererFromBuffer(
76102
model.data(), model.size(), vocab.data(), vocab.size());
77-
if (status.ok()) {
78-
return reinterpret_cast<jlong>(status->release());
103+
if (qa_status.ok()) {
104+
return reinterpret_cast<jlong>(qa_status->release());
79105
} else {
106+
ThrowException(
107+
env, GetExceptionClassNameForStatusCode(qa_status.status().code()),
108+
"Error occurred when initializing BertQuestionAnswerer: %s",
109+
qa_status.status().message().data());
80110
return kInvalidPointer;
81111
}
82112
}
@@ -89,12 +119,16 @@ Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_initJniWithAlbertByte
89119
absl::string_view sp_model =
90120
GetMappedFileBuffer(env, env->GetObjectArrayElement(model_buffers, 1));
91121

92-
tflite::support::StatusOr<std::unique_ptr<QuestionAnswerer>> status =
122+
StatusOr<std::unique_ptr<QuestionAnswerer>> qa_status =
93123
BertQuestionAnswerer::CreateAlbertQuestionAnswererFromBuffer(
94124
model.data(), model.size(), sp_model.data(), sp_model.size());
95-
if (status.ok()) {
96-
return reinterpret_cast<jlong>(status->release());
125+
if (qa_status.ok()) {
126+
return reinterpret_cast<jlong>(qa_status->release());
97127
} else {
128+
ThrowException(
129+
env, GetExceptionClassNameForStatusCode(qa_status.status().code()),
130+
"Error occurred when initializing BertQuestionAnswerer: %s",
131+
qa_status.status().message().data());
98132
return kInvalidPointer;
99133
}
100134
}
@@ -123,5 +157,3 @@ Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_answerNative(
123157
return qa_answer;
124158
});
125159
}
126-
127-
} // namespace

0 commit comments

Comments
 (0)