Skip to content

Commit 15522f4

Browse files
lu-wang-gtflite-support-robot
authored andcommitted
Task delegate 11: Expose BaseOptions in Java BertQuestionAnswer
Major changes: (1) Created the `BertQuestionAnswererOptions` class and integrated it with BaseOptions. (2) Added factory create methods that expose `BertQuestionAnswererOptions`. (3) Unified the JNI method to create the C++ task using `initJniWithFileDescriptor`. Removed `initJniWithModelWithMetadataByteBuffers`. PiperOrigin-RevId: 400357315
1 parent ac27e1f commit 15522f4

File tree

4 files changed

+100
-34
lines changed

4 files changed

+100
-34
lines changed

tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ android_library(
3333
javacopts = ["-source 7 -target 7"],
3434
deps = [
3535
"//tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core:base_task_api",
36+
"@com_google_auto_value",
3637
"@org_tensorflow//tensorflow/lite/java:tensorflowlite_java",
3738
],
3839
)

tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/BertQuestionAnswerer.java

Lines changed: 78 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,16 @@
1717

1818
import android.content.Context;
1919
import android.os.ParcelFileDescriptor;
20+
import com.google.auto.value.AutoValue;
2021
import java.io.File;
2122
import java.io.IOException;
2223
import java.nio.ByteBuffer;
2324
import java.util.List;
25+
import org.tensorflow.lite.task.core.BaseOptions;
2426
import org.tensorflow.lite.task.core.BaseTaskApi;
2527
import org.tensorflow.lite.task.core.TaskJniUtils;
2628
import org.tensorflow.lite.task.core.TaskJniUtils.EmptyHandleProvider;
29+
import org.tensorflow.lite.task.core.TaskJniUtils.FdAndOptionsHandleProvider;
2730
import org.tensorflow.lite.task.core.TaskJniUtils.MultipleBuffersHandleProvider;
2831

2932
/**
@@ -60,22 +63,62 @@ public class BertQuestionAnswerer extends BaseTaskApi implements QuestionAnswere
6063
*/
6164
public static BertQuestionAnswerer createFromFile(Context context, String modelPath)
6265
throws IOException {
66+
return createFromFileAndOptions(
67+
context, modelPath, BertQuestionAnswererOptions.builder().build());
68+
}
69+
70+
/**
71+
* Creates a {@link BertQuestionAnswerer} instance from the default {@link
72+
* BertQuestionAnswererOptions}.
73+
*
74+
* @param modelFile a {@link File} object of the model
75+
* @return a {@link BertQuestionAnswerer} instance
76+
* @throws IOException if model file fails to load
77+
* @throws IllegalArgumentException if an argument is invalid
78+
* @throws IllegalStateException if there is an internal error
79+
* @throws RuntimeException if there is an otherwise unspecified error
80+
*/
81+
public static BertQuestionAnswerer createFromFile(File modelFile) throws IOException {
82+
return createFromFileAndOptions(modelFile, BertQuestionAnswererOptions.builder().build());
83+
}
84+
85+
/**
86+
* Creates a {@link BertQuestionAnswerer} instance from {@link BertQuestionAnswererOptions}.
87+
*
88+
* @param context android context
89+
* @param modelPath file path to the model with metadata. Note: The model should not be compressed
90+
* @return a {@link BertQuestionAnswerer} instance
91+
* @throws IOException if model file fails to load
92+
* @throws IllegalArgumentException if an argument is invalid
93+
* @throws IllegalStateException if there is an internal error
94+
* @throws RuntimeException if there is an otherwise unspecified error
95+
*/
96+
public static BertQuestionAnswerer createFromFileAndOptions(
97+
Context context, String modelPath, BertQuestionAnswererOptions options) throws IOException {
6398
return new BertQuestionAnswerer(
64-
TaskJniUtils.createHandleWithMultipleAssetFilesFromLibrary(
99+
TaskJniUtils.createHandleFromFdAndOptions(
65100
context,
66-
new MultipleBuffersHandleProvider() {
101+
new FdAndOptionsHandleProvider<BertQuestionAnswererOptions>() {
67102
@Override
68-
public long createHandle(ByteBuffer... buffers) {
69-
return BertQuestionAnswerer.initJniWithModelWithMetadataByteBuffers(buffers);
103+
public long createHandle(
104+
int fileDescriptor,
105+
long fileDescriptorLength,
106+
long fileDescriptorOffset,
107+
BertQuestionAnswererOptions options) {
108+
return initJniWithFileDescriptor(
109+
fileDescriptor,
110+
fileDescriptorLength,
111+
fileDescriptorOffset,
112+
TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions()));
70113
}
71114
},
72115
BERT_QUESTION_ANSWERER_NATIVE_LIBNAME,
73-
modelPath));
116+
modelPath,
117+
options));
74118
}
75119

76120
/**
77-
* Creates a {@link BertQuestionAnswerer} instance from the default {@link
78-
* BertQuestionAnswererOptions}.
121+
* Creates a {@link BertQuestionAnswerer} instance from {@link BertQuestionAnswererOptions}.
79122
*
80123
* @param modelFile a {@link File} object of the model
81124
* @return a {@link BertQuestionAnswerer} instance
@@ -84,7 +127,8 @@ public long createHandle(ByteBuffer... buffers) {
84127
* @throws IllegalStateException if there is an internal error
85128
* @throws RuntimeException if there is an otherwise unspecified error
86129
*/
87-
public static BertQuestionAnswerer createFromFile(File modelFile) throws IOException {
130+
public static BertQuestionAnswerer createFromFileAndOptions(
131+
File modelFile, final BertQuestionAnswererOptions options) throws IOException {
88132
try (ParcelFileDescriptor descriptor =
89133
ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
90134
return new BertQuestionAnswerer(
@@ -95,7 +139,8 @@ public long createHandle() {
95139
return initJniWithFileDescriptor(
96140
/*fileDescriptor=*/ descriptor.getFd(),
97141
/*fileDescriptorLength=*/ OPTIONAL_FD_LENGTH,
98-
/*fileDescriptorOffset=*/ OPTIONAL_FD_OFFSET);
142+
/*fileDescriptorOffset=*/ OPTIONAL_FD_OFFSET,
143+
TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions()));
99144
}
100145
},
101146
BERT_QUESTION_ANSWERER_NATIVE_LIBNAME));
@@ -164,6 +209,26 @@ public long createHandle(ByteBuffer... buffers) {
164209
sentencePieceModelPath));
165210
}
166211

212+
/** Options for setting up a {@link BertQuestionAnswerer}. */
213+
@AutoValue
214+
public abstract static class BertQuestionAnswererOptions {
215+
abstract BaseOptions getBaseOptions();
216+
217+
public static Builder builder() {
218+
return new AutoValue_BertQuestionAnswerer_BertQuestionAnswererOptions.Builder()
219+
.setBaseOptions(BaseOptions.builder().build());
220+
}
221+
222+
/** Builder for {@link BertQuestionAnswererOptions}. */
223+
@AutoValue.Builder
224+
public abstract static class Builder {
225+
/** Sets the general options to configure Task APIs, such as accelerators. */
226+
public abstract Builder setBaseOptions(BaseOptions baseOptions);
227+
228+
public abstract BertQuestionAnswererOptions build();
229+
}
230+
}
231+
167232
@Override
168233
public List<QaAnswer> answer(String context, String question) {
169234
checkNotClosed();
@@ -181,11 +246,11 @@ private BertQuestionAnswerer(long nativeHandle) {
181246
// buffer.
182247
private static native long initJniWithAlbertByteBuffers(ByteBuffer... modelBuffers);
183248

184-
// modelBuffers[0] is tflite model file buffer with metadata to specify which tokenizer to use.
185-
private static native long initJniWithModelWithMetadataByteBuffers(ByteBuffer... modelBuffers);
186-
187249
private static native long initJniWithFileDescriptor(
188-
int fileDescriptor, long fileDescriptorLength, long fileDescriptorOffset);
250+
int fileDescriptor,
251+
long fileDescriptorLength,
252+
long fileDescriptorOffset,
253+
long baseOptionsHandle);
189254

190255
private static native List<QaAnswer> answerNative(
191256
long nativeHandle, String context, String question);

tensorflow_lite_support/java/src/native/task/text/qa/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,15 @@ jni_binary_with_tflite(
1313
name = "libtask_text_jni.so",
1414
srcs = [
1515
"bert_question_answerer_jni.cc",
16+
"//tensorflow_lite_support/java/src/native/task/core:task_jni_utils.cc",
1617
],
1718
linkscript = "//tensorflow_lite_support/java:default_version_script.lds",
1819
tflite_deps = [
1920
"//tensorflow_lite_support/cc/task/text:bert_question_answerer",
2021
"//tensorflow_lite_support/cc/utils:jni_utils",
2122
],
2223
deps = [
24+
"//tensorflow_lite_support/cc/task/core/proto:base_options_proto_inc",
2325
"//tensorflow_lite_support/java/jni",
2426
],
2527
)

tensorflow_lite_support/java/src/native/task/text/qa/bert_question_answerer_jni.cc

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ limitations under the License.
1515

1616
#include <jni.h>
1717

18+
#include "tensorflow_lite_support/cc/task/core/proto/base_options_proto_inc.h"
1819
#include "tensorflow_lite_support/cc/task/text/bert_question_answerer.h"
1920
#include "tensorflow_lite_support/cc/utils/jni_utils.h"
2021

@@ -26,13 +27,27 @@ using ::tflite::support::utils::GetExceptionClassNameForStatusCode;
2627
using ::tflite::support::utils::GetMappedFileBuffer;
2728
using ::tflite::support::utils::JStringToString;
2829
using ::tflite::support::utils::ThrowException;
30+
using ::tflite::task::core::BaseOptions;
2931
using ::tflite::task::text::BertQuestionAnswerer;
3032
using ::tflite::task::text::BertQuestionAnswererOptions;
3133
using ::tflite::task::text::QaAnswer;
3234
using ::tflite::task::text::QuestionAnswerer;
3335

3436
constexpr int kInvalidPointer = 0;
3537

38+
// Creates a BertQuestionAnswererOptions proto based on the Java class.
39+
BertQuestionAnswererOptions ConvertToProtoOptions(jlong base_options_handle) {
40+
BertQuestionAnswererOptions proto_options;
41+
42+
if (base_options_handle != kInvalidPointer) {
43+
// proto_options will free the previous base_options and set the new one.
44+
proto_options.set_allocated_base_options(
45+
reinterpret_cast<BaseOptions*>(base_options_handle));
46+
}
47+
48+
return proto_options;
49+
}
50+
3651
} // namespace
3752

3853
extern "C" JNIEXPORT void JNICALL
@@ -41,30 +56,13 @@ Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_deinitJni(
4156
delete reinterpret_cast<QuestionAnswerer*>(native_handle);
4257
}
4358

44-
extern "C" JNIEXPORT jlong JNICALL
45-
Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_initJniWithModelWithMetadataByteBuffers(
46-
JNIEnv* env, jclass thiz, jobjectArray model_buffers) {
47-
absl::string_view model_with_metadata =
48-
GetMappedFileBuffer(env, env->GetObjectArrayElement(model_buffers, 0));
49-
50-
tflite::support::StatusOr<std::unique_ptr<QuestionAnswerer>> qa_status =
51-
BertQuestionAnswerer::CreateFromBuffer(model_with_metadata.data(),
52-
model_with_metadata.size());
53-
if (qa_status.ok()) {
54-
return reinterpret_cast<jlong>(qa_status->release());
55-
} else {
56-
ThrowException(
57-
env, GetExceptionClassNameForStatusCode(qa_status.status().code()),
58-
"Error occurred when initializing BertQuestionAnswerer: %s",
59-
qa_status.status().message().data());
60-
return kInvalidPointer;
61-
}
62-
}
6359
extern "C" JNIEXPORT jlong JNICALL
6460
Java_org_tensorflow_lite_task_text_qa_BertQuestionAnswerer_initJniWithFileDescriptor(
6561
JNIEnv* env, jclass thiz, jint file_descriptor,
66-
jlong file_descriptor_length, jlong file_descriptor_offset) {
67-
BertQuestionAnswererOptions proto_options;
62+
jlong file_descriptor_length, jlong file_descriptor_offset,
63+
jlong base_options_handle) {
64+
BertQuestionAnswererOptions proto_options =
65+
ConvertToProtoOptions(base_options_handle);
6866
auto file_descriptor_meta = proto_options.mutable_base_options()
6967
->mutable_model_file()
7068
->mutable_file_descriptor_meta();

0 commit comments

Comments
 (0)