Skip to content

Commit ac27e1f

Browse files
lu-wang-gtflite-support-robot
authored andcommitted
Task delegate 12: Expose BaseOptions in Java BertNLClassifier
PiperOrigin-RevId: 400355143
1 parent fac2a17 commit ac27e1f

File tree

3 files changed

+42
-15
lines changed

3 files changed

+42
-15
lines changed

tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifier.java

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import java.util.List;
2626
import org.tensorflow.lite.annotations.UsedByReflection;
2727
import org.tensorflow.lite.support.label.Category;
28+
import org.tensorflow.lite.task.core.BaseOptions;
2829
import org.tensorflow.lite.task.core.BaseTaskApi;
2930
import org.tensorflow.lite.task.core.TaskJniUtils;
3031
import org.tensorflow.lite.task.core.TaskJniUtils.EmptyHandleProvider;
@@ -54,14 +55,21 @@ public abstract static class BertNLClassifierOptions {
5455
@UsedByReflection("bert_nl_classifier_jni.cc")
5556
abstract int getMaxSeqLen();
5657

57-
public static BertNLClassifierOptions.Builder builder() {
58+
abstract BaseOptions getBaseOptions();
59+
60+
public static Builder builder() {
5861
return new AutoValue_BertNLClassifier_BertNLClassifierOptions.Builder()
59-
.setMaxSeqLen(DEFAULT_MAX_SEQ_LEN);
62+
.setMaxSeqLen(DEFAULT_MAX_SEQ_LEN)
63+
.setBaseOptions(BaseOptions.builder().build());
6064
}
6165

6266
/** Builder for {@link BertNLClassifierOptions}. */
6367
@AutoValue.Builder
6468
public abstract static class Builder {
69+
70+
/** Sets the general options to configure Task APIs, such as accelerators. */
71+
public abstract Builder setBaseOptions(BaseOptions baseOptions);
72+
6573
public abstract BertNLClassifierOptions.Builder setMaxSeqLen(int value);
6674

6775
public abstract BertNLClassifierOptions build();
@@ -139,7 +147,10 @@ public static BertNLClassifier createFromFileAndOptions(
139147
new EmptyHandleProvider() {
140148
@Override
141149
public long createHandle() {
142-
return initJniWithFileDescriptor(descriptor.getFd(), options);
150+
return initJniWithFileDescriptor(
151+
descriptor.getFd(),
152+
options,
153+
TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions()));
143154
}
144155
},
145156
BERT_NL_CLASSIFIER_NATIVE_LIBNAME));
@@ -179,7 +190,10 @@ public static BertNLClassifier createFromBufferAndOptions(
179190
new EmptyHandleProvider() {
180191
@Override
181192
public long createHandle() {
182-
return initJniWithByteBuffer(modelBuffer, options);
193+
return initJniWithByteBuffer(
194+
modelBuffer,
195+
options,
196+
TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions()));
183197
}
184198
},
185199
BERT_NL_CLASSIFIER_NATIVE_LIBNAME));
@@ -196,9 +210,10 @@ public List<Category> classify(String text) {
196210
}
197211

198212
private static native long initJniWithByteBuffer(
199-
ByteBuffer modelBuffer, BertNLClassifierOptions options);
213+
ByteBuffer modelBuffer, BertNLClassifierOptions options, long baseOptionsHandle);
200214

201-
private static native long initJniWithFileDescriptor(int fd, BertNLClassifierOptions options);
215+
private static native long initJniWithFileDescriptor(
216+
int fd, BertNLClassifierOptions options, long baseOptionsHandle);
202217

203218
private static native List<Category> classifyNative(long nativeHandle, String text);
204219

tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ jni_binary_with_tflite(
2020
name = "libtask_text_jni.so",
2121
srcs = [
2222
"bert_nl_classifier_jni.cc",
23+
"//tensorflow_lite_support/java/src/native/task/core:task_jni_utils.cc",
2324
],
2425
linkscript = "//tensorflow_lite_support/java:default_version_script.lds",
2526
tflite_deps = [
@@ -28,6 +29,7 @@ jni_binary_with_tflite(
2829
"//tensorflow_lite_support/java/src/native/task/text/nlclassifier:nl_classifier_jni_utils",
2930
],
3031
deps = [
32+
"//tensorflow_lite_support/cc/task/core/proto:base_options_proto_inc",
3133
"//tensorflow_lite_support/cc/task/text/proto:bert_nl_classifier_options_proto_inc",
3234
"//tensorflow_lite_support/java/jni",
3335
],

tensorflow_lite_support/java/src/native/task/text/nlclassifier/bert/bert_nl_classifier_jni.cc

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ limitations under the License.
1515

1616
#include <jni.h>
1717

18+
#include "tensorflow_lite_support/cc/task/core/proto/base_options_proto_inc.h"
1819
#include "tensorflow_lite_support/cc/task/text/bert_nl_classifier.h"
1920
#include "tensorflow_lite_support/cc/task/text/proto/bert_nl_classifier_options_proto_inc.h"
2021
#include "tensorflow_lite_support/cc/utils/jni_utils.h"
@@ -25,13 +26,21 @@ namespace {
2526
using ::tflite::support::utils::kAssertionError;
2627
using ::tflite::support::utils::kInvalidPointer;
2728
using ::tflite::support::utils::ThrowException;
29+
using ::tflite::task::core::BaseOptions;
2830
using ::tflite::task::text::BertNLClassifier;
2931
using ::tflite::task::text::BertNLClassifierOptions;
3032
using ::tflite::task::text::nlclassifier::RunClassifier;
3133

3234
BertNLClassifierOptions ConvertJavaBertNLClassifierOptions(
33-
JNIEnv* env, jobject java_options) {
35+
JNIEnv* env, jobject java_options, jlong base_options_handle) {
3436
BertNLClassifierOptions proto_options;
37+
38+
if (base_options_handle != kInvalidPointer) {
39+
// proto_options will free the previous base_options and set the new one.
40+
proto_options.set_allocated_base_options(
41+
reinterpret_cast<BaseOptions*>(base_options_handle));
42+
}
43+
3544
jclass java_options_class = env->FindClass(
3645
"org/tensorflow/lite/task/text/nlclassifier/"
3746
"BertNLClassifier$BertNLClassifierOptions");
@@ -41,6 +50,7 @@ BertNLClassifierOptions ConvertJavaBertNLClassifierOptions(
4150
env->CallIntMethod(java_options, max_seq_len_id));
4251
return proto_options;
4352
}
53+
} // namespace
4454

4555
extern "C" JNIEXPORT void JNICALL
4656
Java_org_tensorflow_lite_task_text_nlclassifier_BertNLClassifier_deinitJni(
@@ -50,9 +60,10 @@ Java_org_tensorflow_lite_task_text_nlclassifier_BertNLClassifier_deinitJni(
5060

5161
extern "C" JNIEXPORT jlong JNICALL
5262
Java_org_tensorflow_lite_task_text_nlclassifier_BertNLClassifier_initJniWithByteBuffer(
53-
JNIEnv* env, jclass thiz, jobject model_buffer, jobject java_options) {
54-
BertNLClassifierOptions proto_options =
55-
ConvertJavaBertNLClassifierOptions(env, java_options);
63+
JNIEnv* env, jclass thiz, jobject model_buffer, jobject java_options,
64+
jlong base_options_handle) {
65+
BertNLClassifierOptions proto_options = ConvertJavaBertNLClassifierOptions(
66+
env, java_options, base_options_handle);
5667
proto_options.mutable_base_options()->mutable_model_file()->set_file_content(
5768
static_cast<char*>(env->GetDirectBufferAddress(model_buffer)),
5869
static_cast<size_t>(env->GetDirectBufferCapacity(model_buffer)));
@@ -71,9 +82,10 @@ Java_org_tensorflow_lite_task_text_nlclassifier_BertNLClassifier_initJniWithByte
7182

7283
extern "C" JNIEXPORT jlong JNICALL
7384
Java_org_tensorflow_lite_task_text_nlclassifier_BertNLClassifier_initJniWithFileDescriptor(
74-
JNIEnv* env, jclass thiz, jint fd, jobject java_options) {
75-
BertNLClassifierOptions proto_options =
76-
ConvertJavaBertNLClassifierOptions(env, java_options);
85+
JNIEnv* env, jclass thiz, jint fd, jobject java_options,
86+
jlong base_options_handle) {
87+
BertNLClassifierOptions proto_options = ConvertJavaBertNLClassifierOptions(
88+
env, java_options, base_options_handle);
7789
proto_options.mutable_base_options()
7890
->mutable_model_file()
7991
->mutable_file_descriptor_meta()
@@ -96,5 +108,3 @@ Java_org_tensorflow_lite_task_text_nlclassifier_BertNLClassifier_classifyNative(
96108
JNIEnv* env, jclass clazz, jlong native_handle, jstring text) {
97109
return RunClassifier(env, native_handle, text);
98110
}
99-
100-
} // namespace

0 commit comments

Comments
 (0)