Skip to content

Commit cb753ca

Browse files
lu-wang-gtflite-support-robot
authored andcommitted
Task delegate 7: Expose BaseOptions in Java ObjectDetector
PiperOrigin-RevId: 400229475
1 parent 1f9bfc0 commit cb753ca

File tree

3 files changed

+58
-20
lines changed

3 files changed

+58
-20
lines changed

tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/detector/ObjectDetector.java

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import org.tensorflow.lite.annotations.UsedByReflection;
2929
import org.tensorflow.lite.support.image.MlImageAdapter;
3030
import org.tensorflow.lite.support.image.TensorImage;
31+
import org.tensorflow.lite.task.core.BaseOptions;
3132
import org.tensorflow.lite.task.core.TaskJniUtils;
3233
import org.tensorflow.lite.task.core.TaskJniUtils.EmptyHandleProvider;
3334
import org.tensorflow.lite.task.core.TaskJniUtils.FdAndOptionsHandleProvider;
@@ -151,7 +152,12 @@ public long createHandle(
151152
long fileDescriptorOffset,
152153
ObjectDetectorOptions options) {
153154
return initJniWithModelFdAndOptions(
154-
fileDescriptor, fileDescriptorLength, fileDescriptorOffset, options);
155+
fileDescriptor,
156+
fileDescriptorLength,
157+
fileDescriptorOffset,
158+
options,
159+
TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads(
160+
options.getBaseOptions(), options.getNumThreads()));
155161
}
156162
},
157163
OBJECT_DETECTOR_NATIVE_LIB,
@@ -180,7 +186,9 @@ public long createHandle() {
180186
descriptor.getFd(),
181187
/*fileDescriptorLength=*/ OPTIONAL_FD_LENGTH,
182188
/*fileDescriptorOffset=*/ OPTIONAL_FD_OFFSET,
183-
options);
189+
options,
190+
TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads(
191+
options.getBaseOptions(), options.getNumThreads()));
184192
}
185193
},
186194
OBJECT_DETECTOR_NATIVE_LIB));
@@ -209,7 +217,11 @@ public static ObjectDetector createFromBufferAndOptions(
209217
new EmptyHandleProvider() {
210218
@Override
211219
public long createHandle() {
212-
return initJniWithByteBuffer(modelBuffer, options);
220+
return initJniWithByteBuffer(
221+
modelBuffer,
222+
options,
223+
TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads(
224+
options.getBaseOptions(), options.getNumThreads()));
213225
}
214226
},
215227
OBJECT_DETECTOR_NATIVE_LIB));
@@ -233,6 +245,7 @@ public static class ObjectDetectorOptions {
233245
// 1. java.util.Optional require Java 8 while we need to support Java 7.
234246
// 2. The Guava library (com.google.common.base.Optional) is avoided in this project. See the
235247
// comments for labelAllowList.
248+
private final BaseOptions baseOptions;
236249
private final String displayNamesLocale;
237250
private final int maxResults;
238251
private final float scoreThreshold;
@@ -252,6 +265,7 @@ public static Builder builder() {
252265

253266
/** A builder that helps to configure an instance of ObjectDetectorOptions. */
254267
public static class Builder {
268+
private BaseOptions baseOptions = BaseOptions.builder().build();
255269
private String displayNamesLocale = "en";
256270
private int maxResults = -1;
257271
private float scoreThreshold;
@@ -262,6 +276,12 @@ public static class Builder {
262276

263277
private Builder() {}
264278

279+
/** Sets the general options to configure Task APIs, such as accelerators. */
280+
public Builder setBaseOptions(BaseOptions baseOptions) {
281+
this.baseOptions = baseOptions;
282+
return this;
283+
}
284+
265285
/**
266286
* Sets the locale to use for display names specified through the TFLite Model Metadata, if
267287
* any.
@@ -335,7 +355,11 @@ public Builder setLabelDenyList(List<String> labelDenyList) {
335355
*
336356
* <p>numThreads should be greater than 0 or equal to -1. Setting numThreads to -1 has the
337357
* effect to let TFLite runtime set the value.
358+
*
359+
* @deprecated use {@link BaseOptions} to configure number of threads instead. This method
360+
* will override the number of threads configured from {@link BaseOptions}.
338361
*/
362+
@Deprecated
339363
public Builder setNumThreads(int numThreads) {
340364
this.numThreads = numThreads;
341365
return this;
@@ -381,6 +405,10 @@ public int getNumThreads() {
381405
return numThreads;
382406
}
383407

408+
public BaseOptions getBaseOptions() {
409+
return baseOptions;
410+
}
411+
384412
private ObjectDetectorOptions(Builder builder) {
385413
displayNamesLocale = builder.displayNamesLocale;
386414
maxResults = builder.maxResults;
@@ -389,6 +417,7 @@ private ObjectDetectorOptions(Builder builder) {
389417
labelAllowList = builder.labelAllowList;
390418
labelDenyList = builder.labelDenyList;
391419
numThreads = builder.numThreads;
420+
baseOptions = builder.baseOptions;
392421
}
393422
}
394423

@@ -463,8 +492,7 @@ public List<Detection> detect(MlImage image) {
463492
}
464493

465494
/**
466-
* Performs actual detection on the provided {@code MlImage} with {@link
467-
* ImageProcessingOptions}.
495+
* Performs actual detection on the provided {@code MlImage} with {@link ImageProcessingOptions}.
468496
*
469497
* <p>{@link ObjectDetector} supports the following options:
470498
*
@@ -497,10 +525,11 @@ private static native long initJniWithModelFdAndOptions(
497525
int fileDescriptor,
498526
long fileDescriptorLength,
499527
long fileDescriptorOffset,
500-
ObjectDetectorOptions options);
528+
ObjectDetectorOptions options,
529+
long baseOptionsHandle);
501530

502531
private static native long initJniWithByteBuffer(
503-
ByteBuffer modelBuffer, ObjectDetectorOptions options);
532+
ByteBuffer modelBuffer, ObjectDetectorOptions options, long baseOptionsHandle);
504533

505534
private static native List<Detection> detectNative(long nativeHandle, long frameBufferHandle);
506535

tensorflow_lite_support/java/src/native/task/vision/detector/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ jni_binary_with_tflite(
1818
name = "libtask_vision_jni.so",
1919
srcs = [
2020
"object_detector_jni.cc",
21+
"//tensorflow_lite_support/java/src/native/task/core:task_jni_utils.cc",
2122
"//tensorflow_lite_support/java/src/native/task/vision/core:base_vision_task_api_jni.cc",
2223
],
2324
linkscript = "//tensorflow_lite_support/java:default_version_script.lds",
@@ -28,6 +29,7 @@ jni_binary_with_tflite(
2829
],
2930
deps = [
3031
"//tensorflow_lite_support/cc/port:statusor",
32+
"//tensorflow_lite_support/cc/task/core/proto:base_options_proto_inc",
3133
"//tensorflow_lite_support/cc/task/vision/core:frame_buffer",
3234
"//tensorflow_lite_support/cc/task/vision/proto:bounding_box_proto_inc",
3335
"//tensorflow_lite_support/cc/task/vision/proto:detections_proto_inc",

tensorflow_lite_support/java/src/native/task/vision/detector/object_detector_jni.cc

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ limitations under the License.
2020

2121
#include "external/com_google_absl/absl/strings/string_view.h"
2222
#include "tensorflow_lite_support/cc/port/statusor.h"
23+
#include "tensorflow_lite_support/cc/task/core/proto/base_options_proto_inc.h"
2324
#include "tensorflow_lite_support/cc/task/vision/core/frame_buffer.h"
2425
#include "tensorflow_lite_support/cc/task/vision/object_detector.h"
2526
#include "tensorflow_lite_support/cc/task/vision/proto/bounding_box_proto_inc.h"
@@ -36,16 +37,26 @@ using ::tflite::support::utils::kAssertionError;
3637
using ::tflite::support::utils::kInvalidPointer;
3738
using ::tflite::support::utils::StringListToVector;
3839
using ::tflite::support::utils::ThrowException;
40+
using ::tflite::task::core::BaseOptions;
3941
using ::tflite::task::vision::BoundingBox;
4042
using ::tflite::task::vision::ConvertToCategory;
4143
using ::tflite::task::vision::DetectionResult;
4244
using ::tflite::task::vision::FrameBuffer;
45+
4346
using ::tflite::task::vision::ObjectDetector;
4447
using ::tflite::task::vision::ObjectDetectorOptions;
4548

4649
// Creates an ObjectDetectorOptions proto based on the Java class.
47-
ObjectDetectorOptions ConvertToProtoOptions(JNIEnv* env, jobject java_options) {
50+
ObjectDetectorOptions ConvertToProtoOptions(JNIEnv* env, jobject java_options,
51+
jlong base_options_handle) {
4852
ObjectDetectorOptions proto_options;
53+
54+
if (base_options_handle != kInvalidPointer) {
55+
// proto_options will free the previous base_options and set the new one.
56+
proto_options.set_allocated_base_options(
57+
reinterpret_cast<BaseOptions*>(base_options_handle));
58+
}
59+
4960
jclass java_options_class = env->FindClass(
5061
"org/tensorflow/lite/task/vision/detector/"
5162
"ObjectDetector$ObjectDetectorOptions");
@@ -91,12 +102,6 @@ ObjectDetectorOptions ConvertToProtoOptions(JNIEnv* env, jobject java_options) {
91102
for (const auto& class_name : deny_list_vector) {
92103
proto_options.add_class_name_blacklist(class_name);
93104
}
94-
95-
jmethodID num_threads_id =
96-
env->GetMethodID(java_options_class, "getNumThreads", "()I");
97-
jint num_threads = env->CallIntMethod(java_options, num_threads_id);
98-
proto_options.set_num_threads(num_threads);
99-
100105
return proto_options;
101106
}
102107

@@ -177,10 +182,11 @@ extern "C" JNIEXPORT jlong JNICALL
177182
Java_org_tensorflow_lite_task_vision_detector_ObjectDetector_initJniWithModelFdAndOptions(
178183
JNIEnv* env, jclass thiz, jint file_descriptor,
179184
jlong file_descriptor_length, jlong file_descriptor_offset,
180-
jobject java_options) {
185+
jobject java_options, jlong base_options_handle) {
181186
ObjectDetectorOptions proto_options =
182-
ConvertToProtoOptions(env, java_options);
183-
auto file_descriptor_meta = proto_options.mutable_model_file_with_metadata()
187+
ConvertToProtoOptions(env, java_options, base_options_handle);
188+
auto file_descriptor_meta = proto_options.mutable_base_options()
189+
->mutable_model_file()
184190
->mutable_file_descriptor_meta();
185191
file_descriptor_meta->set_fd(file_descriptor);
186192
if (file_descriptor_length > 0) {
@@ -194,10 +200,11 @@ Java_org_tensorflow_lite_task_vision_detector_ObjectDetector_initJniWithModelFdA
194200

195201
extern "C" JNIEXPORT jlong JNICALL
196202
Java_org_tensorflow_lite_task_vision_detector_ObjectDetector_initJniWithByteBuffer(
197-
JNIEnv* env, jclass thiz, jobject model_buffer, jobject java_options) {
203+
JNIEnv* env, jclass thiz, jobject model_buffer, jobject java_options,
204+
jlong base_options_handle) {
198205
ObjectDetectorOptions proto_options =
199-
ConvertToProtoOptions(env, java_options);
200-
proto_options.mutable_model_file_with_metadata()->set_file_content(
206+
ConvertToProtoOptions(env, java_options, base_options_handle);
207+
proto_options.mutable_base_options()->mutable_model_file()->set_file_content(
201208
static_cast<char*>(env->GetDirectBufferAddress(model_buffer)),
202209
static_cast<size_t>(env->GetDirectBufferCapacity(model_buffer)));
203210
return CreateObjectDetectorFromOptions(env, proto_options);

0 commit comments

Comments
 (0)