@@ -15,6 +15,7 @@ limitations under the License.
1515
1616#include < jni.h>
1717
18+ #include " tensorflow_lite_support/cc/task/core/proto/base_options_proto_inc.h"
1819#include " tensorflow_lite_support/cc/task/text/bert_nl_classifier.h"
1920#include " tensorflow_lite_support/cc/task/text/proto/bert_nl_classifier_options_proto_inc.h"
2021#include " tensorflow_lite_support/cc/utils/jni_utils.h"
@@ -25,13 +26,21 @@ namespace {
2526using ::tflite::support::utils::kAssertionError ;
2627using ::tflite::support::utils::kInvalidPointer ;
2728using ::tflite::support::utils::ThrowException;
29+ using ::tflite::task::core::BaseOptions;
2830using ::tflite::task::text::BertNLClassifier;
2931using ::tflite::task::text::BertNLClassifierOptions;
3032using ::tflite::task::text::nlclassifier::RunClassifier;
3133
3234BertNLClassifierOptions ConvertJavaBertNLClassifierOptions (
33- JNIEnv* env, jobject java_options) {
35+ JNIEnv* env, jobject java_options, jlong base_options_handle ) {
3436 BertNLClassifierOptions proto_options;
37+
38+ if (base_options_handle != kInvalidPointer ) {
39+ // proto_options will free the previous base_options and set the new one.
40+ proto_options.set_allocated_base_options (
41+ reinterpret_cast <BaseOptions*>(base_options_handle));
42+ }
43+
3544 jclass java_options_class = env->FindClass (
3645 " org/tensorflow/lite/task/text/nlclassifier/"
3746 " BertNLClassifier$BertNLClassifierOptions" );
@@ -41,6 +50,7 @@ BertNLClassifierOptions ConvertJavaBertNLClassifierOptions(
4150 env->CallIntMethod (java_options, max_seq_len_id));
4251 return proto_options;
4352}
53+ } // namespace
4454
4555extern " C" JNIEXPORT void JNICALL
4656Java_org_tensorflow_lite_task_text_nlclassifier_BertNLClassifier_deinitJni (
@@ -50,9 +60,10 @@ Java_org_tensorflow_lite_task_text_nlclassifier_BertNLClassifier_deinitJni(
5060
5161extern " C" JNIEXPORT jlong JNICALL
5262Java_org_tensorflow_lite_task_text_nlclassifier_BertNLClassifier_initJniWithByteBuffer (
53- JNIEnv* env, jclass thiz, jobject model_buffer, jobject java_options) {
54- BertNLClassifierOptions proto_options =
55- ConvertJavaBertNLClassifierOptions (env, java_options);
63+ JNIEnv* env, jclass thiz, jobject model_buffer, jobject java_options,
64+ jlong base_options_handle) {
65+ BertNLClassifierOptions proto_options = ConvertJavaBertNLClassifierOptions (
66+ env, java_options, base_options_handle);
5667 proto_options.mutable_base_options ()->mutable_model_file ()->set_file_content (
5768 static_cast <char *>(env->GetDirectBufferAddress (model_buffer)),
5869 static_cast <size_t >(env->GetDirectBufferCapacity (model_buffer)));
@@ -71,9 +82,10 @@ Java_org_tensorflow_lite_task_text_nlclassifier_BertNLClassifier_initJniWithByte
7182
7283extern " C" JNIEXPORT jlong JNICALL
7384Java_org_tensorflow_lite_task_text_nlclassifier_BertNLClassifier_initJniWithFileDescriptor (
74- JNIEnv* env, jclass thiz, jint fd, jobject java_options) {
75- BertNLClassifierOptions proto_options =
76- ConvertJavaBertNLClassifierOptions (env, java_options);
85+ JNIEnv* env, jclass thiz, jint fd, jobject java_options,
86+ jlong base_options_handle) {
87+ BertNLClassifierOptions proto_options = ConvertJavaBertNLClassifierOptions (
88+ env, java_options, base_options_handle);
7789 proto_options.mutable_base_options ()
7890 ->mutable_model_file ()
7991 ->mutable_file_descriptor_meta ()
@@ -96,5 +108,3 @@ Java_org_tensorflow_lite_task_text_nlclassifier_BertNLClassifier_classifyNative(
96108 JNIEnv* env, jclass clazz, jlong native_handle, jstring text) {
97109 return RunClassifier (env, native_handle, text);
98110}
99-
100- } // namespace
0 commit comments