Skip to content

Commit e25cb4f

Browse files
lu-wang-gtflite-support-robot
authored andcommitted
Task delegate 8: Expose BaseOptions in Java ImageSegmenter
PiperOrigin-RevId: 400230011
1 parent cb753ca commit e25cb4f

File tree

3 files changed

+36
-14
lines changed

3 files changed

+36
-14
lines changed

tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/ImageSegmenter.java

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import java.util.List;
3131
import org.tensorflow.lite.support.image.MlImageAdapter;
3232
import org.tensorflow.lite.support.image.TensorImage;
33+
import org.tensorflow.lite.task.core.BaseOptions;
3334
import org.tensorflow.lite.task.core.TaskJniUtils;
3435
import org.tensorflow.lite.task.core.TaskJniUtils.EmptyHandleProvider;
3536
import org.tensorflow.lite.task.core.vision.ImageProcessingOptions;
@@ -186,7 +187,8 @@ public long createHandle() {
186187
modelBuffer,
187188
options.getDisplayNamesLocale(),
188189
options.getOutputType().getValue(),
189-
options.getNumThreads());
190+
TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads(
191+
options.getBaseOptions(), options.getNumThreads()));
190192
}
191193
},
192194
IMAGE_SEGMENTER_NATIVE_LIB),
@@ -210,6 +212,8 @@ public abstract static class ImageSegmenterOptions {
210212
private static final OutputType DEFAULT_OUTPUT_TYPE = OutputType.CATEGORY_MASK;
211213
private static final int NUM_THREADS = -1;
212214

215+
public abstract BaseOptions getBaseOptions();
216+
213217
public abstract String getDisplayNamesLocale();
214218

215219
public abstract OutputType getOutputType();
@@ -220,13 +224,17 @@ public static Builder builder() {
220224
return new AutoValue_ImageSegmenter_ImageSegmenterOptions.Builder()
221225
.setDisplayNamesLocale(DEFAULT_DISPLAY_NAME_LOCALE)
222226
.setOutputType(DEFAULT_OUTPUT_TYPE)
223-
.setNumThreads(NUM_THREADS);
227+
.setNumThreads(NUM_THREADS)
228+
.setBaseOptions(BaseOptions.builder().build());
224229
}
225230

226231
/** Builder for {@link ImageSegmenterOptions}. */
227232
@AutoValue.Builder
228233
public abstract static class Builder {
229234

235+
/** Sets the general options to configure Task APIs, such as accelerators. */
236+
public abstract Builder setBaseOptions(BaseOptions baseOptions);
237+
230238
/**
231239
* Sets the locale to use for display names specified through the TFLite Model Metadata, if
232240
* any.
@@ -245,7 +253,11 @@ public abstract static class Builder {
245253
*
246254
* <p>numThreads should be greater than 0 or equal to -1. Setting numThreads to -1 has the
247255
* effect to let TFLite runtime set the value.
256+
*
257+
* @deprecated use {@link BaseOptions} to configure number of threads instead. This method
258+
* will override the number of threads configured from {@link BaseOptions}.
248259
*/
260+
@Deprecated
249261
public abstract Builder setNumThreads(int numThreads);
250262

251263
public abstract ImageSegmenterOptions build();
@@ -402,7 +414,8 @@ public long createHandle() {
402414
fileDescriptorOffset,
403415
options.getDisplayNamesLocale(),
404416
options.getOutputType().getValue(),
405-
options.getNumThreads());
417+
TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads(
418+
options.getBaseOptions(), options.getNumThreads()));
406419
}
407420
},
408421
IMAGE_SEGMENTER_NATIVE_LIB);
@@ -415,10 +428,10 @@ private static native long initJniWithModelFdAndOptions(
415428
long fileDescriptorOffset,
416429
String displayNamesLocale,
417430
int outputType,
418-
int numThreads);
431+
long baseOptionsHandle);
419432

420433
private static native long initJniWithByteBuffer(
421-
ByteBuffer modelBuffer, String displayNamesLocale, int outputType, int numThreads);
434+
ByteBuffer modelBuffer, String displayNamesLocale, int outputType, long baseOptionsHandle);
422435

423436
/**
424437
* The native method to segment the image.

tensorflow_lite_support/java/src/native/task/vision/segmenter/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ jni_binary_with_tflite(
1818
name = "libtask_vision_jni.so",
1919
srcs = [
2020
"image_segmenter_jni.cc",
21+
"//tensorflow_lite_support/java/src/native/task/core:task_jni_utils.cc",
2122
"//tensorflow_lite_support/java/src/native/task/vision/core:base_vision_task_api_jni.cc",
2223
],
2324
linkscript = "//tensorflow_lite_support/java:default_version_script.lds",
@@ -28,6 +29,7 @@ jni_binary_with_tflite(
2829
],
2930
deps = [
3031
"//tensorflow_lite_support/cc/port:statusor",
32+
"//tensorflow_lite_support/cc/task/core/proto:base_options_proto_inc",
3133
"//tensorflow_lite_support/cc/task/vision/core:frame_buffer",
3234
"//tensorflow_lite_support/cc/task/vision/proto:image_segmenter_options_proto_inc",
3335
"//tensorflow_lite_support/cc/task/vision/proto:segmentations_proto_inc",

tensorflow_lite_support/java/src/native/task/vision/segmenter/image_segmenter_jni.cc

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ limitations under the License.
2020

2121
#include "external/com_google_absl/absl/strings/str_cat.h"
2222
#include "tensorflow_lite_support/cc/port/statusor.h"
23+
#include "tensorflow_lite_support/cc/task/core/proto/base_options_proto_inc.h"
2324
#include "tensorflow_lite_support/cc/task/vision/core/frame_buffer.h"
2425
#include "tensorflow_lite_support/cc/task/vision/image_segmenter.h"
2526
#include "tensorflow_lite_support/cc/task/vision/proto/image_segmenter_options_proto_inc.h"
@@ -36,6 +37,7 @@ using ::tflite::support::utils::kAssertionError;
3637
using ::tflite::support::utils::kIllegalArgumentException;
3738
using ::tflite::support::utils::kInvalidPointer;
3839
using ::tflite::support::utils::ThrowException;
40+
using ::tflite::task::core::BaseOptions;
3941
using ::tflite::task::vision::FrameBuffer;
4042
using ::tflite::task::vision::ImageSegmenter;
4143
using ::tflite::task::vision::ImageSegmenterOptions;
@@ -58,9 +60,15 @@ constexpr int kOutputTypeConfidenceMask = 1;
5860
ImageSegmenterOptions ConvertToProtoOptions(JNIEnv* env,
5961
jstring display_names_locale,
6062
jint output_type,
61-
jint num_threads) {
63+
jlong base_options_handle) {
6264
ImageSegmenterOptions proto_options;
6365

66+
if (base_options_handle != kInvalidPointer) {
67+
// proto_options will free the previous base_options and set the new one.
68+
proto_options.set_allocated_base_options(
69+
reinterpret_cast<BaseOptions*>(base_options_handle));
70+
}
71+
6472
const char* pchars = env->GetStringUTFChars(display_names_locale, nullptr);
6573
proto_options.set_display_names_locale(pchars);
6674
env->ReleaseStringUTFChars(display_names_locale, pchars);
@@ -78,8 +86,6 @@ ImageSegmenterOptions ConvertToProtoOptions(JNIEnv* env,
7886
"Unsupported output type: %d", output_type);
7987
}
8088

81-
proto_options.set_num_threads(num_threads);
82-
8389
return proto_options;
8490
}
8591

@@ -185,10 +191,11 @@ extern "C" JNIEXPORT jlong JNICALL
185191
Java_org_tensorflow_lite_task_vision_segmenter_ImageSegmenter_initJniWithModelFdAndOptions(
186192
JNIEnv* env, jclass thiz, jint file_descriptor,
187193
jlong file_descriptor_length, jlong file_descriptor_offset,
188-
jstring display_names_locale, jint output_type, jint num_threads) {
194+
jstring display_names_locale, jint output_type, jlong base_options_handle) {
189195
ImageSegmenterOptions proto_options = ConvertToProtoOptions(
190-
env, display_names_locale, output_type, num_threads);
191-
auto file_descriptor_meta = proto_options.mutable_model_file_with_metadata()
196+
env, display_names_locale, output_type, base_options_handle);
197+
auto file_descriptor_meta = proto_options.mutable_base_options()
198+
->mutable_model_file()
192199
->mutable_file_descriptor_meta();
193200
file_descriptor_meta->set_fd(file_descriptor);
194201
if (file_descriptor_length > 0) {
@@ -203,10 +210,10 @@ Java_org_tensorflow_lite_task_vision_segmenter_ImageSegmenter_initJniWithModelFd
203210
extern "C" JNIEXPORT jlong JNICALL
204211
Java_org_tensorflow_lite_task_vision_segmenter_ImageSegmenter_initJniWithByteBuffer(
205212
JNIEnv* env, jclass thiz, jobject model_buffer,
206-
jstring display_names_locale, jint output_type, jint num_threads) {
213+
jstring display_names_locale, jint output_type, jlong base_options_handle) {
207214
ImageSegmenterOptions proto_options = ConvertToProtoOptions(
208-
env, display_names_locale, output_type, num_threads);
209-
proto_options.mutable_model_file_with_metadata()->set_file_content(
215+
env, display_names_locale, output_type, base_options_handle);
216+
proto_options.mutable_base_options()->mutable_model_file()->set_file_content(
210217
static_cast<char*>(env->GetDirectBufferAddress(model_buffer)),
211218
static_cast<size_t>(env->GetDirectBufferCapacity(model_buffer)));
212219
return CreateImageSegmenterFromOptions(env, proto_options);

0 commit comments

Comments
 (0)