Skip to content

Commit 1f9bfc0

Browse files
lu-wang-gtflite-support-robot
authored andcommitted
Task delegate 6: Refactor the way to configure num_threads
In ImageClassifier, ImageClassifierOptions contains both BaseOptions and NumThreads (configured through the legacy API). Previously, NumThreads is merged into BaseOptions in the JNI layer of ImageClassifier. This new refactor merges NumThreads into BaseOptions in the Java layer, which simplifies the logic a little bit and helps to reduce complexity when adding BaseOptions into the other tasks. PiperOrigin-RevId: 400223492
1 parent 8e78765 commit 1f9bfc0

File tree

3 files changed

+17
-19
lines changed

3 files changed

+17
-19
lines changed

tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/core/TaskJniUtils.java

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ public static MappedByteBuffer loadMappedFile(Context context, String filePath)
147147
}
148148

149149
/**
150-
* Try load a native library, if it's already loaded return directly.
150+
* Try loading a native library, if it's already loaded return directly.
151151
*
152152
* @param libName name of the lib
153153
*/
@@ -162,8 +162,17 @@ public static void tryLoadLibrary(String libName) {
162162
}
163163

164164
public static long createProtoBaseOptionsHandle(BaseOptions baseOptions) {
165+
return createProtoBaseOptionsHandleWithLegacyNumThreads(baseOptions, /*legacyNumThreads =*/ -1);
166+
}
167+
168+
public static long createProtoBaseOptionsHandleWithLegacyNumThreads(
169+
BaseOptions baseOptions, int legacyNumThreads) {
170+
// NumThreads should be configured through BaseOptions. However, if NumThreads is configured
171+
// through the legacy API of the Task Java API (then it will not equal to -1, the default
172+
// value), use it to overide the one in baseOptions.
165173
return createProtoBaseOptions(
166-
baseOptions.getComputeSettings().getDelegate().getValue(), baseOptions.getNumThreads());
174+
baseOptions.getComputeSettings().getDelegate().getValue(),
175+
legacyNumThreads == -1 ? baseOptions.getNumThreads() : legacyNumThreads);
167176
}
168177

169178
private TaskJniUtils() {}

tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/classifier/ImageClassifier.java

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,8 @@ public long createHandle(
144144
fileDescriptorLength,
145145
fileDescriptorOffset,
146146
options,
147-
TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions()));
147+
TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads(
148+
options.getBaseOptions(), options.getNumThreads()));
148149
}
149150
},
150151
IMAGE_CLASSIFIER_NATIVE_LIB,
@@ -175,7 +176,8 @@ public long createHandle() {
175176
/*fileDescriptorLength=*/ OPTIONAL_FD_LENGTH,
176177
/*fileDescriptorOffset=*/ OPTIONAL_FD_OFFSET,
177178
options,
178-
TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions()));
179+
TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads(
180+
options.getBaseOptions(), options.getNumThreads()));
179181
}
180182
},
181183
IMAGE_CLASSIFIER_NATIVE_LIB));
@@ -209,7 +211,8 @@ public long createHandle() {
209211
return initJniWithByteBuffer(
210212
modelBuffer,
211213
options,
212-
TaskJniUtils.createProtoBaseOptionsHandle(options.getBaseOptions()));
214+
TaskJniUtils.createProtoBaseOptionsHandleWithLegacyNumThreads(
215+
options.getBaseOptions(), options.getNumThreads()));
213216
}
214217
},
215218
IMAGE_CLASSIFIER_NATIVE_LIB));

tensorflow_lite_support/java/src/native/task/vision/classifier/image_classifier_jni.cc

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -100,20 +100,6 @@ ImageClassifierOptions ConvertToProtoOptions(JNIEnv* env, jobject java_options,
100100
for (const auto& class_name : deny_list_vector) {
101101
proto_options.add_class_name_blacklist(class_name);
102102
}
103-
104-
jmethodID num_threads_id =
105-
env->GetMethodID(java_options_class, "getNumThreads", "()I");
106-
jint num_threads = env->CallIntMethod(java_options, num_threads_id);
107-
// Use base_options to configure num_threads, because image_classifier is
108-
// created using base_options in initJniWithModelFdAndOptions and
109-
// initJniWithByteBuffer.
110-
if (num_threads != -1) {
111-
proto_options.mutable_base_options()
112-
->mutable_compute_settings()
113-
->mutable_tflite_settings()
114-
->mutable_cpu_settings()
115-
->set_num_threads(num_threads);
116-
}
117103
return proto_options;
118104
}
119105

0 commit comments

Comments
 (0)